Commit f08b802

feat: add support for OneTrainer BFL Flux LoRA format (#8984)
* feat: add support for OneTrainer BFL Flux LoRA format

  Newer versions of OneTrainer export Flux LoRAs using BFL internal key names
  (double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix
  and split QKV projections (qkv.0/1/2, linear1.0/1/2/3). This format was not
  recognized by any existing detector. Add detection and conversion for this
  format, merging the split QKV and linear1 layers into MergedLayerPatch
  instances for the fused BFL model.

* chore: ruff
1 parent ae42182 commit f08b802

4 files changed: 180 additions & 0 deletions
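Before the file-by-file diff, a note on what the converter is doing: OneTrainer exports separate LoRA adapters for the Q, K, and V projections, while InvokeAI's fused BFL FLUX model implements them as a single qkv linear layer. The converter therefore stacks each split adapter's output rows into one MergedLayerPatch so the patch lines up with the fused weight. A small illustrative sketch of that row bookkeeping, assuming the standard FLUX hidden size of 3072 (the single_blocks linear1 case works the same way, with a fourth split for the MLP projection):

    # Illustrative only: how the split adapters' output rows map onto the fused
    # qkv weight. 3072 is the FLUX hidden size, so the fused qkv projection has
    # 3 * 3072 = 9216 output rows.
    hidden = 3072
    splits = [("qkv.0 (Q)", hidden), ("qkv.1 (K)", hidden), ("qkv.2 (V)", hidden)]

    offset = 0
    for name, out_dim in splits:
        print(f"{name} patches fused rows [{offset}, {offset + out_dim})")
        offset += out_dim
    # qkv.0 (Q) patches fused rows [0, 3072)
    # qkv.1 (K) patches fused rows [3072, 6144)
    # qkv.2 (V) patches fused rows [6144, 9216)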

File tree

invokeai/backend/model_manager/load/model_loaders/lora.py
invokeai/backend/model_manager/taxonomy.py
invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py
invokeai/backend/patches/lora_conversions/formats.py

invokeai/backend/model_manager/load/model_loaders/lora.py

Lines changed: 6 additions & 0 deletions
@@ -44,6 +44,10 @@
     is_state_dict_likely_in_flux_kohya_format,
     lora_model_from_flux_kohya_state_dict,
 )
+from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_onetrainer_bfl_format,
+    lora_model_from_flux_onetrainer_bfl_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
     is_state_dict_likely_in_flux_onetrainer_format,
     lora_model_from_flux_onetrainer_state_dict,

@@ -128,6 +132,8 @@ def _load_model(
             model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
         elif is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
             model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
+        elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict=state_dict):
+            model = lora_model_from_flux_onetrainer_bfl_state_dict(state_dict=state_dict)
         elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
             model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
         elif is_state_dict_likely_flux_control(state_dict=state_dict):
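The new branch sits ahead of the existing OneTrainer check; the elif chain is first-match-wins, so the more specific detector must probe first. A rough sketch of what it accepts and rejects (toy key sets; detection only inspects key names, so placeholder values suffice, and the Kohya-style key is shown purely as a contrast):

    from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
        is_state_dict_likely_in_flux_onetrainer_bfl_format,
    )

    bfl_keys = {
        "transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight": None,
        "transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight": None,
        "transformer.double_blocks.0.img_attn.qkv.0.alpha": None,
    }
    kohya_style_key = {"lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight": None}

    assert is_state_dict_likely_in_flux_onetrainer_bfl_format(bfl_keys)
    assert not is_state_dict_likely_in_flux_onetrainer_bfl_format(kohya_style_key)  # no 'transformer.' prefix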

invokeai/backend/model_manager/taxonomy.py

Lines changed: 1 addition & 0 deletions
@@ -210,6 +210,7 @@ class FluxLoRAFormat(str, Enum):
     AIToolkit = "flux.aitoolkit"
     XLabs = "flux.xlabs"
     BflPeft = "flux.bfl_peft"
+    OneTrainerBfl = "flux.onetrainer_bfl"


 AnyVariant: TypeAlias = Union[
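Since FluxLoRAFormat subclasses str, the new member compares equal to, and reconstructs from, its raw string value. This is standard str-Enum behavior, shown here only as a sanity check:

    from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat

    assert FluxLoRAFormat.OneTrainerBfl == "flux.onetrainer_bfl"
    assert FluxLoRAFormat("flux.onetrainer_bfl") is FluxLoRAFormat.OneTrainerBfl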
invokeai/backend/patches/lora_conversions/flux_onetrainer_bfl_lora_conversion_utils.py

Lines changed: 168 additions & 0 deletions

@@ -0,0 +1,168 @@
+"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format.
+
+This format is produced by newer versions of OneTrainer and uses BFL internal key names
+(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and
+InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha).
+
+Unlike the standard BFL PEFT format (which uses 'diffusion_model.' prefix and lora_A/lora_B),
+this format also has split QKV projections:
+- double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate)
+- double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate)
+- single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate)
+
+Example keys:
+    transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight
+    transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight
+    transformer.double_blocks.0.img_attn.qkv.0.alpha
+    transformer.single_blocks.0.linear1.3.lora_down.weight
+    transformer.double_blocks.0.img_mlp.0.lora_down.weight
+"""
+
+import re
+from typing import Any, Dict
+
+import torch
+
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+
+_TRANSFORMER_PREFIX = "transformer."
+
+# Valid LoRA weight suffixes in this format.
+_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha")
+
+# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1"
+_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$")
+
+# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2"
+_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$")
+
+
+def is_state_dict_likely_in_flux_onetrainer_bfl_format(
+    state_dict: dict[str | int, Any],
+    metadata: dict[str, Any] | None = None,
+) -> bool:
+    """Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format.
+
+    This format uses BFL internal key names with 'transformer.' prefix and split QKV projections.
+    """
+    str_keys = [k for k in state_dict.keys() if isinstance(k, str)]
+    if not str_keys:
+        return False
+
+    # All keys must start with 'transformer.'
+    if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys):
+        return False
+
+    # All keys must end with recognized LoRA suffixes.
+    if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys):
+        return False
+
+    # Must have BFL block structure (double_blocks or single_blocks) under transformer prefix.
+    has_bfl_blocks = any(
+        k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys
+    )
+    if not has_bfl_blocks:
+        return False
+
+    # Must have split QKV pattern (qkv.0, qkv.1, qkv.2) to distinguish from other formats
+    # that might use transformer. prefix in the future.
+    has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys)
+    if not has_split_qkv:
+        return False
+
+    return True
+
+
+def _split_key(key: str) -> tuple[str, str]:
+    """Split a key into (layer_name, weight_suffix).
+
+    Handles:
+    - 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot
+    - 1-component suffixes: e.g., 'alpha' → split at last dot
+    """
+    if key.endswith(".weight"):
+        parts = key.rsplit(".", maxsplit=2)
+        return parts[0], f"{parts[1]}.{parts[2]}"
+    else:
+        parts = key.rsplit(".", maxsplit=1)
+        return parts[0], parts[1]
+
+
+def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
+    """Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw.
+
+    Strips the 'transformer.' prefix, groups by layer, and merges split QKV/linear1
+    layers into MergedLayerPatch instances.
+    """
+    # Step 1: Strip prefix and group by layer name.
+    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+    for key, value in state_dict.items():
+        if not isinstance(key, str):
+            continue
+
+        # Strip 'transformer.' prefix.
+        key = key[len(_TRANSFORMER_PREFIX) :]
+
+        layer_name, suffix = _split_key(key)
+
+        if layer_name not in grouped_state_dict:
+            grouped_state_dict[layer_name] = {}
+        grouped_state_dict[layer_name][suffix] = value
+
+    # Step 2: Build LoRA layers, merging split QKV and linear1.
+    layers: dict[str, BaseLayerPatch] = {}
+
+    # Identify which layers need merging.
+    merge_groups: dict[str, list[str]] = {}
+    standalone_keys: list[str] = []
+
+    for layer_key in grouped_state_dict:
+        qkv_match = _SPLIT_QKV_RE.match(layer_key)
+        linear1_match = _SPLIT_LINEAR1_RE.match(layer_key)
+
+        if qkv_match:
+            parent = qkv_match.group(1)
+            if parent not in merge_groups:
+                merge_groups[parent] = []
+            merge_groups[parent].append(layer_key)
+        elif linear1_match:
+            parent = linear1_match.group(1)
+            if parent not in merge_groups:
+                merge_groups[parent] = []
+            merge_groups[parent].append(layer_key)
+        else:
+            standalone_keys.append(layer_key)
+
+    # Process standalone layers.
+    for layer_key in standalone_keys:
+        layer_sd = grouped_state_dict[layer_key]
+        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{layer_key}"] = any_lora_layer_from_state_dict(layer_sd)
+
+    # Process merged layers.
+    for parent_key, sub_keys in merge_groups.items():
+        # Sort by the numeric index at the end (e.g., qkv.0, qkv.1, qkv.2).
+        sub_keys.sort(key=lambda k: int(k.rsplit(".", maxsplit=1)[1]))

+        sub_layers: list[BaseLayerPatch] = []
+        sub_ranges: list[Range] = []
+        dim_0_offset = 0
+
+        for sub_key in sub_keys:
+            layer_sd = grouped_state_dict[sub_key]
+            sub_layer = any_lora_layer_from_state_dict(layer_sd)
+
+            # Determine the output dimension from the up weight shape.
+            up_weight = layer_sd["lora_up.weight"]
+            out_dim = up_weight.shape[0]
+
+            sub_layers.append(sub_layer)
+            sub_ranges.append(Range(dim_0_offset, dim_0_offset + out_dim))
+            dim_0_offset += out_dim
+
+        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent_key}"] = MergedLayerPatch(sub_layers, sub_ranges)
+
+    return ModelPatchRaw(layers=layers)
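A minimal end-to-end sketch of the conversion path. The shapes here are toy values rather than real FLUX dimensions, and it assumes ModelPatchRaw exposes its layer mapping as a layers attribute:

    import torch

    from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
        is_state_dict_likely_in_flux_onetrainer_bfl_format,
        lora_model_from_flux_onetrainer_bfl_state_dict,
    )

    rank, in_dim, out_dim = 4, 32, 16  # toy dimensions, not real FLUX sizes
    sd: dict[str, torch.Tensor] = {}
    for i in range(3):  # the split Q, K, V adapters
        base = f"transformer.double_blocks.0.img_attn.qkv.{i}"
        sd[f"{base}.lora_down.weight"] = torch.randn(rank, in_dim)
        sd[f"{base}.lora_up.weight"] = torch.randn(out_dim, rank)
        sd[f"{base}.alpha"] = torch.tensor(float(rank))  # alpha == rank gives a LoRA scale of 1.0

    assert is_state_dict_likely_in_flux_onetrainer_bfl_format(sd)
    patch = lora_model_from_flux_onetrainer_bfl_state_dict(sd)

    # The three split adapters collapse into a single MergedLayerPatch keyed on
    # the fused layer name, covering rows [0, 16), [16, 32), and [32, 48).
    print(list(patch.layers.keys()))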

invokeai/backend/patches/lora_conversions/formats.py

Lines changed: 5 additions & 0 deletions
@@ -14,6 +14,9 @@
 from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
     is_state_dict_likely_in_flux_kohya_format,
 )
+from invokeai.backend.patches.lora_conversions.flux_onetrainer_bfl_lora_conversion_utils import (
+    is_state_dict_likely_in_flux_onetrainer_bfl_format,
+)
 from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
     is_state_dict_likely_in_flux_onetrainer_format,
 )

@@ -28,6 +31,8 @@ def flux_format_from_state_dict(
 ) -> FluxLoRAFormat | None:
     if is_state_dict_likely_in_flux_kohya_format(state_dict):
         return FluxLoRAFormat.Kohya
+    elif is_state_dict_likely_in_flux_onetrainer_bfl_format(state_dict, metadata):
+        return FluxLoRAFormat.OneTrainerBfl
     elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
         return FluxLoRAFormat.OneTrainer
     elif is_state_dict_likely_in_flux_diffusers_format(state_dict):