"""Utilities for detecting and converting FLUX LoRAs in OneTrainer BFL format.

This format is produced by newer versions of OneTrainer and uses BFL internal key names
(double_blocks, single_blocks, img_attn, etc.) with a 'transformer.' prefix and
InvokeAI-native LoRA suffixes (lora_down.weight, lora_up.weight, alpha).

Unlike the standard BFL PEFT format (which uses 'diffusion_model.' prefix and lora_A/lora_B),
this format also has split QKV projections:
  - double_blocks.{i}.img_attn.qkv.{0,1,2} (Q, K, V separate)
  - double_blocks.{i}.txt_attn.qkv.{0,1,2} (Q, K, V separate)
  - single_blocks.{i}.linear1.{0,1,2,3} (Q, K, V, MLP separate)

Example keys:
  transformer.double_blocks.0.img_attn.qkv.0.lora_down.weight
  transformer.double_blocks.0.img_attn.qkv.0.lora_up.weight
  transformer.double_blocks.0.img_attn.qkv.0.alpha
  transformer.single_blocks.0.linear1.3.lora_down.weight
  transformer.double_blocks.0.img_mlp.0.lora_down.weight
"""
import re
from typing import Any, Dict

import torch

from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.merged_layer_patch import MergedLayerPatch, Range
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw

| 32 | +_TRANSFORMER_PREFIX = "transformer." |
| 33 | + |
| 34 | +# Valid LoRA weight suffixes in this format. |
| 35 | +_LORA_SUFFIXES = ("lora_down.weight", "lora_up.weight", "alpha") |
| 36 | + |
| 37 | +# Regex to detect split QKV keys in double blocks: e.g. "double_blocks.0.img_attn.qkv.1" |
| 38 | +_SPLIT_QKV_RE = re.compile(r"^(double_blocks\.\d+\.(img_attn|txt_attn)\.qkv)\.\d+$") |
| 39 | + |
| 40 | +# Regex to detect split linear1 keys in single blocks: e.g. "single_blocks.0.linear1.2" |
| 41 | +_SPLIT_LINEAR1_RE = re.compile(r"^(single_blocks\.\d+\.linear1)\.\d+$") |
| 42 | + |
| 43 | + |
| 44 | +def is_state_dict_likely_in_flux_onetrainer_bfl_format( |
| 45 | + state_dict: dict[str | int, Any], |
| 46 | + metadata: dict[str, Any] | None = None, |
| 47 | +) -> bool: |
| 48 | + """Checks if the provided state dict is likely in the OneTrainer BFL FLUX LoRA format. |
| 49 | +
|
| 50 | + This format uses BFL internal key names with 'transformer.' prefix and split QKV projections. |
| 51 | + """ |
| 52 | + str_keys = [k for k in state_dict.keys() if isinstance(k, str)] |
| 53 | + if not str_keys: |
| 54 | + return False |
| 55 | + |
| 56 | + # All keys must start with 'transformer.' |
| 57 | + if not all(k.startswith(_TRANSFORMER_PREFIX) for k in str_keys): |
| 58 | + return False |
| 59 | + |
| 60 | + # All keys must end with recognized LoRA suffixes. |
| 61 | + if not all(k.endswith(_LORA_SUFFIXES) for k in str_keys): |
| 62 | + return False |
| 63 | + |
| 64 | + # Must have BFL block structure (double_blocks or single_blocks) under transformer prefix. |
| 65 | + has_bfl_blocks = any( |
| 66 | + k.startswith("transformer.double_blocks.") or k.startswith("transformer.single_blocks.") for k in str_keys |
| 67 | + ) |
| 68 | + if not has_bfl_blocks: |
| 69 | + return False |
| 70 | + |
| 71 | + # Must have split QKV pattern (qkv.0, qkv.1, qkv.2) to distinguish from other formats |
| 72 | + # that might use transformer. prefix in the future. |
| 73 | + has_split_qkv = any(".qkv.0." in k or ".qkv.1." in k or ".qkv.2." in k or ".linear1.0." in k for k in str_keys) |
| 74 | + if not has_split_qkv: |
| 75 | + return False |
| 76 | + |
| 77 | + return True |
| 78 | + |
| 79 | + |
| 80 | +def _split_key(key: str) -> tuple[str, str]: |
| 81 | + """Split a key into (layer_name, weight_suffix). |
| 82 | +
|
| 83 | + Handles: |
| 84 | + - 2-component suffixes ending with '.weight': e.g., 'lora_down.weight' → split at 2nd-to-last dot |
| 85 | + - 1-component suffixes: e.g., 'alpha' → split at last dot |
| 86 | + """ |
| 87 | + if key.endswith(".weight"): |
| 88 | + parts = key.rsplit(".", maxsplit=2) |
| 89 | + return parts[0], f"{parts[1]}.{parts[2]}" |
| 90 | + else: |
| 91 | + parts = key.rsplit(".", maxsplit=1) |
| 92 | + return parts[0], parts[1] |
| 93 | + |
| 94 | + |
def lora_model_from_flux_onetrainer_bfl_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Convert a OneTrainer BFL format FLUX LoRA state dict to a ModelPatchRaw.

    Strips the 'transformer.' prefix, groups by layer, and merges split QKV/linear1
    layers into MergedLayerPatch instances.
    """
    # Step 1: Strip the 'transformer.' prefix and bucket tensors by layer name.
    grouped: dict[str, dict[str, torch.Tensor]] = {}
    for full_key, tensor in state_dict.items():
        if not isinstance(full_key, str):
            continue

        stripped_key = full_key[len(_TRANSFORMER_PREFIX) :]
        layer_name, suffix = _split_key(stripped_key)
        grouped.setdefault(layer_name, {})[suffix] = tensor

    # Step 2: Partition layer names into split QKV/linear1 groups (to be merged under their
    # parent layer name) and standalone layers.
    merge_groups: dict[str, list[str]] = {}
    standalone_names: list[str] = []
    for name in grouped:
        split_match = _SPLIT_QKV_RE.match(name) or _SPLIT_LINEAR1_RE.match(name)
        if split_match:
            merge_groups.setdefault(split_match.group(1), []).append(name)
        else:
            standalone_names.append(name)

    layers: dict[str, BaseLayerPatch] = {}

    # Standalone layers convert one-to-one.
    for name in standalone_names:
        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{name}"] = any_lora_layer_from_state_dict(grouped[name])

    # Split layers are merged into a single MergedLayerPatch spanning the concatenated output dim.
    for parent_name, member_names in merge_groups.items():
        # Order sub-layers by their trailing numeric index (e.g. qkv.0, qkv.1, qkv.2).
        member_names.sort(key=lambda n: int(n.rsplit(".", maxsplit=1)[1]))

        sub_layers: list[BaseLayerPatch] = []
        sub_ranges: list[Range] = []
        offset = 0
        for member_name in member_names:
            member_sd = grouped[member_name]
            sub_layers.append(any_lora_layer_from_state_dict(member_sd))

            # The sub-layer's slice of the merged output is sized by the up-projection's first dim.
            out_dim = member_sd["lora_up.weight"].shape[0]
            sub_ranges.append(Range(offset, offset + out_dim))
            offset += out_dim

        layers[f"{FLUX_LORA_TRANSFORMER_PREFIX}{parent_name}"] = MergedLayerPatch(sub_layers, sub_ranges)

    return ModelPatchRaw(layers=layers)