
Commit 7f2b34b

tcaimm and linoytsaban authored

Add train flux2 series lora config (#13011)

* feat(lora): support FLUX.2 single blocks + update README
* add img2img config & add explanatory comments
* simple modify

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
1 parent e1e7d58 commit 7f2b34b

File tree

5 files changed: +34 −9 lines

examples/dreambooth/README_flux2.md

Lines changed: 6 additions & 5 deletions
@@ -347,16 +347,17 @@
 More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
 applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
 the exact modules for LoRA training. Here are some examples of target modules you can provide:
-- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
-- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
-- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
+- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out"`
+- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.linear_in,ff.linear_out,ff_context.linear_in,ff_context.linear_out,norm_out.linear,norm_out.proj_out"`
 > [!NOTE]
 > `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
 > **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
-> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
 > [!NOTE]
 > keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
-
+> [!NOTE]
+> In FLUX.2, the q, k, and v projections are fused into a single linear layer named `attn.to_qkv_mlp_proj` within the single transformer blocks. Also, the attention output is just `attn.to_out`, not `attn.to_out.0`; it is no longer a ModuleList as in the MMDiT transformer blocks.
 
 ## Training Image-to-Image
 
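The `--lora_layers` flag documented above is consumed by the training scripts with a plain comma split. A minimal sketch of that parsing, using the updated attention-only example from the README as input:

```python
# Sketch: how a --lora_layers value becomes the LoRA target_modules list.
# Mirrors the `[layer.strip() for layer in args.lora_layers.split(",")]`
# logic in the training scripts.
lora_layers = "attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.to_qkv_mlp_proj"
target_modules = [layer.strip() for layer in lora_layers.split(",")]
print(target_modules)
```

Whitespace around the commas is tolerated because each entry is stripped individually.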

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 7 additions & 1 deletion
@@ -1256,7 +1256,13 @@ def main(args):
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
-        target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+        # target_modules = ["to_k", "to_q", "to_v", "to_out.0"]  # just train transformer_blocks
+
+        # train transformer_blocks and single_transformer_blocks
+        target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
+            "to_qkv_mlp_proj",
+            *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
+        ]
 
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
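The default list built above can be sanity-checked standalone. A minimal sketch, where the count of 48 single blocks is taken from the `range(48)` in the diff, and the total of 53 entries is 4 shared layer names plus the fused projection plus 48 per-block paths:

```python
# Sketch of the default target_modules from train_dreambooth_lora_flux2.py
# when --lora_layers is not passed. The single transformer blocks fuse
# q/k/v into attn.to_qkv_mlp_proj, and their attn.to_out is a plain Linear,
# so each one is listed by full path.
NUM_SINGLE_BLOCKS = 48  # matches range(48) in the script

target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
    "to_qkv_mlp_proj",
    *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(NUM_SINGLE_BLOCKS)],
]
print(len(target_modules))  # 4 + 1 + 48 = 53 entries
```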

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 7 additions & 1 deletion
@@ -1206,7 +1206,13 @@ def main(args):
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
-        target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+        # target_modules = ["to_k", "to_q", "to_v", "to_out.0"]  # just train transformer_blocks
+
+        # train transformer_blocks and single_transformer_blocks
+        target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
+            "to_qkv_mlp_proj",
+            *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(48)],
+        ]
 
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 7 additions & 1 deletion
@@ -1249,7 +1249,13 @@ def main(args):
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
-        target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+        # target_modules = ["to_k", "to_q", "to_v", "to_out.0"]  # just train transformer_blocks
+
+        # train transformer_blocks and single_transformer_blocks
+        target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
+            "to_qkv_mlp_proj",
+            *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
+        ]
 
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
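The Klein scripts differ from the FLUX.2 ones only in how many single transformer blocks they enumerate (24 vs. 48, per the `range()` calls in the respective diffs). A hypothetical helper, not present in the scripts, makes the shared structure explicit:

```python
# Sketch (hypothetical helper, not in the training scripts): build the
# default target_modules list for a given number of single transformer blocks.
def default_target_modules(num_single_blocks: int) -> list[str]:
    return ["to_k", "to_q", "to_v", "to_out.0"] + [
        "to_qkv_mlp_proj",
        *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(num_single_blocks)],
    ]

flux2_targets = default_target_modules(48)  # train_dreambooth_lora_flux2*.py
klein_targets = default_target_modules(24)  # train_dreambooth_lora_flux2_klein*.py
```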

examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py

Lines changed: 7 additions & 1 deletion
@@ -1200,7 +1200,13 @@ def main(args):
     if args.lora_layers is not None:
         target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
     else:
-        target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
+        # target_modules = ["to_k", "to_q", "to_v", "to_out.0"]  # just train transformer_blocks
+
+        # train transformer_blocks and single_transformer_blocks
+        target_modules = ["to_k", "to_q", "to_v", "to_out.0"] + [
+            "to_qkv_mlp_proj",
+            *[f"single_transformer_blocks.{i}.attn.to_out" for i in range(24)],
+        ]
 
     # now we will add new LoRA weights the transformer layers
     transformer_lora_config = LoraConfig(
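Why list the single-block attention outputs by full path rather than adding a bare `to_out` entry? Assuming PEFT's usual matching rule (an entry matches when a module's path equals it or ends with `.` plus the entry), a bare `to_out` would also match the MMDiT blocks' `attn.to_out` ModuleList container, which is not a LoRA-compatible Linear layer. A toy matcher illustrating that rule under this assumption:

```python
# Toy sketch of suffix-style target_modules matching (assumption: PEFT matches
# a target entry when the module path equals it or ends with "." + entry).
def matches(module_path: str, targets: list[str]) -> bool:
    return any(module_path == t or module_path.endswith("." + t) for t in targets)

targets = ["to_k", "to_out.0", "single_transformer_blocks.0.attn.to_out"]
print(matches("transformer_blocks.3.attn.to_k", targets))           # True
print(matches("single_transformer_blocks.0.attn.to_out", targets))  # True
print(matches("single_transformer_blocks.1.attn.to_out", targets))  # False
```

Enumerating the full per-block paths therefore targets exactly the single-block Linear layers without touching the MMDiT containers.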
