Commit 0920939

Add three-tier LoKR quality comparison (fuse-first, Kronecker split, SVD)
- Add fuse_qkv parameter to BFL LoKR converter for lossless fuse-first path
- Thread fuse_qkv through lora_pipeline.py (lora_state_dict -> load_lora_weights)
- Fuse model QKV projections before adapter injection when fuse_qkv=True
- Update benchmark script with --tiers, --no-offload flags for all three paths
1 parent b5958e6 commit 0920939

3 files changed

Lines changed: 110 additions & 54 deletions

benchmark_lokr.py

Lines changed: 84 additions & 45 deletions
```diff
@@ -1,12 +1,20 @@
-"""Benchmark: Lossless LoKR vs Lossy LoRA-via-SVD on Flux2 Klein 9B.
+"""Benchmark: Three-tier LoKR quality comparison on Flux2 Klein 9B.
+
+Tier 1 - Fuse-first (lossless): Fuse model QKV, map BFL LoKR directly. Exact.
+Tier 2 - Kronecker split (default): Split fused QKV via Van Loan re-factorization. Slight loss.
+Tier 3 - SVD to LoRA (fully lossy): Convert entire LoKR to LoRA via peft.convert_to_lora.
+
+Tiers 1+2 only apply to BFL-format LoKR (fused QKV). LyCORIS and diffusers-native
+formats already have separate Q/K/V and only run the default path.
 
-Generates images using both conversion paths for visual comparison.
 Uses bf16 with CPU offload.
 
 Usage:
     python benchmark_lokr.py
     python benchmark_lokr.py --lokr-path "puttmorbidly233/lora" --lokr-name "klein_snofs_v1_2.safetensors"
     python benchmark_lokr.py --prompt "a portrait in besch art style" --ranks 32 64 128
+    python benchmark_lokr.py --tiers 1 2  # skip SVD tier
+    python benchmark_lokr.py --tiers 2 3  # skip fuse-first tier
 """
 
 import argparse
@@ -15,18 +23,22 @@
 import time
 
 import torch
+
 from diffusers import Flux2KleinPipeline
-from peft import convert_to_lora
+
 
 MODEL_ID = "black-forest-labs/FLUX.2-klein-9B"
 DEFAULT_LOKR_PATH = "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02"
 OUTPUT_DIR = "benchmark_output"
 
 
-def load_pipeline():
-    """Load Flux2 Klein 9B in bf16 with model CPU offload."""
+def load_pipeline(no_offload=False):
+    """Load Flux2 Klein 9B in bf16."""
     pipe = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
-    pipe.enable_model_cpu_offload()
+    if no_offload:
+        pipe = pipe.to("cuda")
+    else:
+        pipe.enable_model_cpu_offload()
     return pipe
 
 
@@ -44,9 +56,34 @@ def generate(pipe, prompt, seed, num_steps=4, guidance_scale=1.0):
     return image
 
 
-def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name):
-    """Path A: Load LoKR natively (lossless)."""
-    print("\n=== Path A: Lossless LoKR ===")
+def benchmark_baseline(pipe, prompt, seed):
+    """Baseline: No adapter."""
+    print("\n=== Baseline: No adapter ===")
+    t0 = time.time()
+    image = generate(pipe, prompt, seed)
+    print(f" Generated in {time.time() - t0:.1f}s")
+    return image
+
+
+def benchmark_tier1_fuse_first(pipe, prompt, seed, lokr_path, lokr_name):
+    """Tier 1: Fuse model QKV, then load BFL LoKR directly (lossless)."""
+    print("\n=== Tier 1: Fuse-first LoKR (lossless) ===")
+    t0 = time.time()
+    kwargs = {"weight_name": lokr_name} if lokr_name else {}
+    pipe.load_lora_weights(lokr_path, fuse_qkv=True, **kwargs)
+    print(f" Loaded in {time.time() - t0:.1f}s")
+
+    t0 = time.time()
+    image = generate(pipe, prompt, seed)
+    print(f" Generated in {time.time() - t0:.1f}s")
+
+    pipe.unload_lora_weights()
+    return image
+
+
+def benchmark_tier2_kronecker_split(pipe, prompt, seed, lokr_path, lokr_name):
+    """Tier 2: Split fused QKV via Kronecker re-factorization (default path)."""
+    print("\n=== Tier 2: Kronecker split LoKR (default) ===")
     t0 = time.time()
     kwargs = {"weight_name": lokr_name} if lokr_name else {}
     pipe.load_lora_weights(lokr_path, **kwargs)
@@ -60,15 +97,16 @@ def benchmark_lossless(pipe, prompt, seed, lokr_path, lokr_name):
     return image
 
 
-def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name):
-    """Path B: Load LoKR, convert to LoRA via SVD (lossy)."""
-    print(f"\n=== Path B: Lossy LoRA via SVD (rank={rank}) ===")
+def benchmark_tier3_svd(pipe, prompt, seed, rank, lokr_path, lokr_name):
+    """Tier 3: Convert LoKR to LoRA via SVD (fully lossy)."""
+    from peft import convert_to_lora, inject_adapter_in_model, set_peft_model_state_dict
+
+    print(f"\n=== Tier 3: SVD to LoRA (rank={rank}) ===")
     t0 = time.time()
     kwargs = {"weight_name": lokr_name} if lokr_name else {}
     pipe.load_lora_weights(lokr_path, **kwargs)
     load_time = time.time() - t0
 
-    # Detect the actual adapter name assigned by peft
     adapter_name = next(iter(pipe.transformer.peft_config.keys()))
     print(f" Adapter name: {adapter_name}")
 
@@ -77,9 +115,6 @@ def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name):
     convert_time = time.time() - t0
     print(f" Loaded LoKR in {load_time:.1f}s, converted to LoRA in {convert_time:.1f}s")
 
-    # Replace LoKR adapter with converted LoRA
-    from peft import inject_adapter_in_model, set_peft_model_state_dict
-
     pipe.transformer.delete_adapters(adapter_name)
     inject_adapter_in_model(lora_config, pipe.transformer, adapter_name=adapter_name)
     set_peft_model_state_dict(pipe.transformer, lora_sd, adapter_name=adapter_name)
@@ -92,24 +127,18 @@ def benchmark_lossy(pipe, prompt, seed, rank, lokr_path, lokr_name):
     return image
 
 
-def benchmark_baseline(pipe, prompt, seed):
-    """Baseline: No adapter."""
-    print("\n=== Baseline: No adapter ===")
-    t0 = time.time()
-    image = generate(pipe, prompt, seed)
-    print(f" Generated in {time.time() - t0:.1f}s")
-    return image
-
-
 def main():
-    parser = argparse.ArgumentParser(description="Benchmark LoKR vs LoRA-via-SVD")
+    parser = argparse.ArgumentParser(description="Benchmark LoKR quality tiers")
     parser.add_argument("--prompt", default="a portrait painting in besch art style")
     parser.add_argument("--lokr-path", default=DEFAULT_LOKR_PATH, help="HF repo or local path to LoKR checkpoint")
     parser.add_argument("--lokr-name", default=None, help="Filename within HF repo (if multi-file)")
     parser.add_argument("--seed", type=int, default=42)
-    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128])
+    parser.add_argument(
+        "--tiers", type=int, nargs="+", default=[1, 2, 3], help="Tiers to run (1=fuse, 2=kronecker, 3=svd)"
+    )
+    parser.add_argument("--ranks", type=int, nargs="+", default=[32, 64, 128], help="SVD ranks for tier 3")
     parser.add_argument("--skip-baseline", action="store_true")
-    parser.add_argument("--skip-lossy", action="store_true")
+    parser.add_argument("--no-offload", action="store_true", help="Keep model on GPU instead of CPU offload")
     args = parser.parse_args()
 
     os.makedirs(OUTPUT_DIR, exist_ok=True)
@@ -118,11 +147,13 @@ def main():
     print(f"LoKR: {args.lokr_path}" + (f" ({args.lokr_name})" if args.lokr_name else ""))
     print(f"Prompt: {args.prompt}")
     print(f"Seed: {args.seed}")
-    if not args.skip_lossy:
-        print(f"SVD ranks to test: {args.ranks}")
+    print(f"Tiers: {args.tiers}")
+    if 3 in args.tiers:
+        print(f"SVD ranks: {args.ranks}")
 
-    print("\nLoading pipeline (bf16, model CPU offload)...")
-    pipe = load_pipeline()
+    mode = "on GPU" if args.no_offload else "with CPU offload"
+    print(f"\nLoading pipeline (bf16, {mode})...")
+    pipe = load_pipeline(no_offload=args.no_offload)
 
     # Baseline
     if not args.skip_baseline:
@@ -131,28 +162,36 @@ def main():
         img.save(path)
         print(f" Saved: {path}")
 
-    # Path A: Lossless LoKR
-    img = benchmark_lossless(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name)
-    path = os.path.join(OUTPUT_DIR, "lokr_lossless.png")
-    img.save(path)
-    print(f" Saved: {path}")
+    # Tier 1: Fuse-first (lossless, BFL only)
+    if 1 in args.tiers:
+        img = benchmark_tier1_fuse_first(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name)
+        path = os.path.join(OUTPUT_DIR, "tier1_fuse_lossless.png")
+        img.save(path)
+        print(f" Saved: {path}")
+        gc.collect()
+        torch.cuda.empty_cache()
 
-    gc.collect()
-    torch.cuda.empty_cache()
+    # Tier 2: Kronecker split (default)
+    if 2 in args.tiers:
+        img = benchmark_tier2_kronecker_split(pipe, args.prompt, args.seed, args.lokr_path, args.lokr_name)
+        path = os.path.join(OUTPUT_DIR, "tier2_kronecker.png")
+        img.save(path)
+        print(f" Saved: {path}")
+        gc.collect()
+        torch.cuda.empty_cache()
 
-    # Path B: Lossy LoRA via SVD at various ranks
-    if not args.skip_lossy:
+    # Tier 3: SVD to LoRA at various ranks
+    if 3 in args.tiers:
        for rank in args.ranks:
-            img = benchmark_lossy(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name)
-            path = os.path.join(OUTPUT_DIR, f"lora_svd_rank{rank}.png")
+            img = benchmark_tier3_svd(pipe, args.prompt, args.seed, rank, args.lokr_path, args.lokr_name)
+            path = os.path.join(OUTPUT_DIR, f"tier3_svd_rank{rank}.png")
            img.save(path)
            print(f" Saved: {path}")
-
            gc.collect()
            torch.cuda.empty_cache()
 
    print(f"\nAll results saved to {OUTPUT_DIR}/")
-    print("Compare: baseline.png vs lokr_lossless.png vs lora_svd_rank*.png")
+    print("Compare: baseline.png vs tier1_fuse_lossless.png vs tier2_kronecker.png vs tier3_svd_rank*.png")
 
 
 if __name__ == "__main__":
```
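
For reference, the "Van Loan re-factorization" that Tier 2 relies on is the nearest-Kronecker-product construction: rearrange a weight block so that its best Kronecker factorization falls out of a rank-1 SVD. A minimal, self-contained sketch of the idea — the helper name `nearest_kronecker` and the shapes are illustrative, not the converter's actual internals:

```python
import torch


def nearest_kronecker(W, m1, n1, m2, n2):
    """Best Frobenius-norm fit W ~= kron(A, B) via the Van Loan/Pitsianis rearrangement."""
    assert W.shape == (m1 * m2, n1 * n2)
    # View W as an (m1 x n1) grid of (m2 x n2) blocks and flatten each block into
    # one row of R; the best rank-1 factor pair of R is then vec(A), vec(B).
    R = (
        W.reshape(m1, m2, n1, n2)   # W[i*m2 + p, j*n2 + q] -> [i, p, j, q]
        .permute(0, 2, 1, 3)        # -> [i, j, p, q]
        .reshape(m1 * n1, m2 * n2)  # row (i, j) holds vec(block_ij)
    )
    U, S, Vh = torch.linalg.svd(R, full_matrices=False)
    A = (S[0].sqrt() * U[:, 0]).reshape(m1, n1)
    B = (S[0].sqrt() * Vh[0, :]).reshape(m2, n2)
    return A, B


# Sanity check: an exact Kronecker product is recovered with ~zero error.
A, B = torch.randn(4, 3), torch.randn(5, 2)
W = torch.kron(A, B)
A2, B2 = nearest_kronecker(W, 4, 3, 5, 2)
print(torch.linalg.norm(torch.kron(A2, B2) - W))  # ~0
```

The rank-1 truncation in the SVD step is exactly where Tier 2's "slight loss" enters; Tier 1 avoids it by never re-factoring.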

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 18 additions & 8 deletions
```diff
@@ -2688,11 +2688,15 @@ def _split_lokr_qkv(w1, w2, target_keys, factor):
     return result
 
 
-def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
+def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=False):
     """Convert BFL-format Flux2 LoKR state dict to peft-compatible diffusers format.
 
-    Handles fused QKV by splitting via Kronecker re-factorization (Van Loan algorithm).
-    Non-QKV modules are remapped directly. Alpha scaling is baked into lokr_w1.
+    Args:
+        state_dict: BFL-format LoKR state dict with ``diffusion_model.`` prefix.
+        fuse_qkv: If True, map fused QKV directly to ``to_qkv``/``to_added_qkv`` targets
+            (lossless, but requires the model's QKV to be fused before injection).
+            If False (default), split fused QKV into separate Q/K/V via Kronecker
+            re-factorization (slightly lossy, no model fusion needed).
     """
     converted_state_dict = {}
 
@@ -2793,11 +2797,17 @@ def _remap_lokr_qkv(bfl_path, target_keys):
         tb = f"transformer_blocks.{dl}"
         db = f"double_blocks.{dl}"
 
-        # Split fused QKV into separate Q/K/V via Kronecker re-factorization
-        _remap_lokr_qkv(f"{db}.img_attn.qkv", [f"{tb}.attn.to_q", f"{tb}.attn.to_k", f"{tb}.attn.to_v"])
-        _remap_lokr_qkv(
-            f"{db}.txt_attn.qkv", [f"{tb}.attn.add_q_proj", f"{tb}.attn.add_k_proj", f"{tb}.attn.add_v_proj"]
-        )
+        if fuse_qkv:
+            # Lossless: map directly to fused targets (caller must fuse model QKV first)
+            _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv")
+            _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv")
+        else:
+            # Split fused QKV into separate Q/K/V via Kronecker re-factorization
+            _remap_lokr_qkv(f"{db}.img_attn.qkv", [f"{tb}.attn.to_q", f"{tb}.attn.to_k", f"{tb}.attn.to_v"])
+            _remap_lokr_qkv(
+                f"{db}.txt_attn.qkv",
+                [f"{tb}.attn.add_q_proj", f"{tb}.attn.add_k_proj", f"{tb}.attn.add_v_proj"],
+            )
 
         # Projections
         _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0")
```
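
The docstring's lossless/lossy distinction can be made concrete with toy shapes. Below, a fused-QKV update ΔW = kron(w1, w2) is consumed exactly by a fused target, while a Q slice whose boundary cuts through a w2 row-block has no exact Kronecker form and must be re-fit. The shapes are invented for illustration (real factor shapes depend on the checkpoint), and `nearest_kronecker` is the sketch helper from above, not converter code:

```python
import torch

# nearest_kronecker is the helper defined in the earlier sketch.
w1 = torch.randn(4, 4)           # toy LoKR factors: 4 row-blocks of 768 rows each
w2 = torch.randn(768, 16)
delta_qkv = torch.kron(w1, w2)   # (3072, 64) update for a fused QKV weight

# Fuse-first: the fused to_qkv module consumes delta_qkv unchanged -> exact.
# Split path: the Q slice spans 1024 rows = one full w2 block plus a third of the
# next, so it is not a single Kronecker product and the rank-1 re-fit loses energy.
q_slice = delta_qkv[:1024]
a, b = nearest_kronecker(q_slice, 2, 4, 512, 16)
rel_err = torch.linalg.norm(torch.kron(a, b) - q_slice) / torch.linalg.norm(q_slice)
print(f"relative error of the re-factored Q slice: {rel_err.item():.3f}")  # > 0
```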

src/diffusers/loaders/lora_pipeline.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -5648,6 +5648,7 @@ def lora_state_dict(
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+        fuse_qkv = kwargs.pop("fuse_qkv", False)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -5691,14 +5692,16 @@
         is_lokr = any("lokr_" in k for k in state_dict)
         if is_lokr:
             if any(k.startswith("diffusion_model.") for k in state_dict):
-                state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict)
+                state_dict = _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict, fuse_qkv=fuse_qkv)
             elif any(k.startswith("lycoris_") for k in state_dict):
                 state_dict = _convert_lycoris_flux2_lokr_to_diffusers(state_dict)
             else:
                 state_dict = _convert_diffusers_flux2_lokr_to_peft(state_dict)
             if metadata is None:
                 metadata = {}
             metadata["is_lokr"] = "true"
+            if fuse_qkv:
+                metadata["fuse_qkv"] = "true"
         else:
             is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
             if is_ai_toolkit:
@@ -5740,6 +5743,10 @@ def load_lora_weights(
 
         transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
 
+        # Fuse model QKV projections before injection if requested (lossless path for BFL LoKR)
+        if metadata and metadata.get("fuse_qkv") == "true":
+            transformer.fuse_qkv_projections()
+
         self.load_lora_into_transformer(
             state_dict,
             transformer=transformer,
```
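
End to end, the lossless path is now a one-flag change for callers. A usage sketch built from the defaults in the benchmark script; the generation arguments are illustrative rather than prescribed:

```python
import torch
from diffusers import Flux2KleinPipeline

pipe = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-9B", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

# fuse_qkv=True is popped in lora_state_dict, recorded as metadata["fuse_qkv"],
# and load_lora_weights then fuses the transformer's QKV projections before
# injection, so the BFL checkpoint's fused keys land on to_qkv/to_added_qkv exactly.
pipe.load_lora_weights(
    "gattaplayer/besch-flux2-klein-9b-lokr-lion-3e-6-bs2-ga2-v02", fuse_qkv=True
)

image = pipe("a portrait painting in besch art style", num_inference_steps=4).images[0]
image.save("tier1_fuse_lossless.png")
```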
