
Commit 388a539

Merge branch 'main' into exclude-modules-lora
2 parents 5ecf05a + 21543de commit 388a539

14 files changed

Lines changed: 3846 additions & 2 deletions

docs/source/en/api/pipelines/flux.md

Lines changed: 41 additions & 0 deletions
@@ -39,6 +39,7 @@ Flux comes in the following variants:
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
| Kontext | [`black-forest-labs/FLUX.1-Kontext-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) |

All checkpoints have different usage, which we detail below.

@@ -273,6 +274,46 @@ images = pipe(
images[0].save("flux-redux.png")
```

### Kontext

Flux Kontext is a model that enables in-context control of the image generation process, supporting editing, refinement, relighting, style transfer, character customization, and more.

```python
import torch
from diffusers import FluxKontextPipeline
from diffusers.utils import load_image

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png").convert("RGB")
prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
image = pipe(
    image=image,
    prompt=prompt,
    guidance_scale=2.5,
    generator=torch.Generator().manual_seed(42),
).images[0]
image.save("flux-kontext.png")
```

Flux Kontext comes with an integrity safety checker, which should be run after the image generation step. To run the safety checker, install the official repository from [black-forest-labs/flux](https://github.com/black-forest-labs/flux) and add the following code:

```python
import numpy as np

from flux.content_filters import PixtralContentFilter

# ... pipeline invocation to generate images

integrity_checker = PixtralContentFilter(torch.device("cuda"))
image_ = np.array(image) / 255.0
image_ = 2 * image_ - 1
image_ = torch.from_numpy(image_).to("cuda", dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
if integrity_checker.test_image(image_):
    raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
```

## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux

We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-step inference. The example below shows how to do that for the Flux Control LoRA for depth and the turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
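A minimal sketch of that combination is shown below; the exact Hyper-SD LoRA filename, adapter weights, and generation settings are assumptions rather than the documented example.

```python
import torch
from diffusers import FluxControlPipeline
from diffusers.utils import load_image
from huggingface_hub import hf_hub_download

pipe = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Depth Control LoRA plus a turbo (few-step) LoRA from ByteDance/Hyper-SD.
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
pipe.load_lora_weights(
    hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"),
    adapter_name="hyper-sd",
)
# Adapter weights here are illustrative; tune them for your use case.
pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])

control_image = load_image("path/to/depth_map.png")  # precomputed depth map
image = pipe(
    prompt="a photo of a forest cabin",  # example prompt
    control_image=control_image,
    num_inference_steps=8,
    guidance_scale=10.0,
).images[0]
```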

examples/dreambooth/README_flux.md

Lines changed: 46 additions & 0 deletions
@@ -260,5 +260,51 @@ to enable `latent_caching` simply pass `--cache_latents`.
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g., when mixed-precision training is enabled with `--mixed_precision="bf16"`, the final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly without a significant quality loss. Note that if you do wish to save the final layers in float32 at the expense of more memory usage, you can do so by passing `--upcast_before_saving`.
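
As a quick sanity check, you can inspect the dtype of the saved LoRA weights; this is a minimal sketch and the output path is hypothetical.

```python
import safetensors.torch

# Hypothetical output directory from a training run.
lora_path = "trained-flux-lora/pytorch_lora_weights.safetensors"
state_dict = safetensors.torch.load_file(lora_path)

# Without --upcast_before_saving and with --mixed_precision="bf16",
# the tensors are stored in torch.bfloat16.
print({name: tensor.dtype for name, tensor in list(state_dict.items())[:3]})
```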

## Training Kontext

[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply to this script, too.

Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.

Below is an example training command:

```bash
accelerate launch train_dreambooth_lora_flux_kontext.py \
  --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
  --instance_data_dir="dog" \
  --output_dir="kontext-dog" \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --optimizer="adamw" \
  --use_8bit_adam \
  --cache_latents \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --seed="0"
```

Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not perform as expected.
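
After training, a minimal inference sketch could look like the following, assuming the `--output_dir` from the command above and that the pipeline accepts a prompt-only call for T2I.

```python
import torch
from diffusers import FluxKontextPipeline

pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
).to("cuda")
# Load the LoRA saved by the training run (local --output_dir or a Hub repo id).
pipe.load_lora_weights("kontext-dog")

image = pipe(prompt="a photo of sks dog in a bucket", guidance_scale=2.5).images[0]
image.save("kontext-dog.png")
```
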
### Misc notes

* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it; the sketch below illustrates the difference.
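
Illustrative sketch of taking the distribution's mode versus sampling from it; the checkpoint/subfolder used to load the VAE and the existence of a `sample` alternative are assumptions for demonstration only.

```python
import torch
from diffusers import AutoencoderKL

# Any Flux-style VAE works for this demonstration.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", subfolder="vae")

pixels = torch.randn(1, 3, 64, 64)  # stand-in for a preprocessed training image
latent_dist = vae.encode(pixels).latent_dist

latents_from_mode = latent_dist.mode()        # deterministic; what --vae_encode_mode=mode uses
latents_from_sampling = latent_dist.sample()  # stochastic alternative
```
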
### Aspect Ratio Bucketing

We've added aspect ratio bucketing support, which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.

To enable aspect ratio bucketing, pass the `--aspect_ratio_buckets` argument with a semicolon-separated list of `height,width` pairs, such as:

`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"`
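
As a rough illustration of the idea (not necessarily how the training script implements it), each image can be assigned to the bucket whose aspect ratio is closest to its own:

```python
def parse_buckets(arg: str) -> list[tuple[int, int]]:
    # "672,1568;688,1504;..." -> [(672, 1568), (688, 1504), ...]
    return [tuple(int(v) for v in pair.split(",")) for pair in arg.split(";")]

def closest_bucket(height: int, width: int, buckets: list[tuple[int, int]]) -> tuple[int, int]:
    # Pick the bucket whose height/width ratio is nearest to the image's ratio.
    ratio = height / width
    return min(buckets, key=lambda hw: abs(hw[0] / hw[1] - ratio))

buckets = parse_buckets("672,1568;1024,1024;1568,672")
print(closest_bucket(900, 1200, buckets))  # -> (1024, 1024)
```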

Since Flux Kontext fine-tuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗

## Other notes

Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFluxKontext(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_flux_kontext.py"
    transformer_layer_type = "single_transformer_blocks.0.attn.to_k"

    def test_dreambooth_lora_flux_kontext(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_flux_kontext(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only params of
            # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict.
            starts_with_transformer = all(
                key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)

0 commit comments
