Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 251 additions & 6 deletions modules/module/LoRAModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from modules.module.quantized.LinearSVD import BaseLinearSVD
from modules.util.config.TrainConfig import TrainConfig
from modules.util.enum.ModelType import PeftType
from modules.util.lokr_utils import factorization, make_kron, rebuild_tucker
from modules.util.ModuleFilter import ModuleFilter
from modules.util.quantization_util import get_unquantized_weight, get_weight_shape

Expand Down Expand Up @@ -284,6 +285,223 @@ def extract_from_module(self, base_module: nn.Module):
pass


class LoKrModule(PeftBase):
"""Implementation of LoKr from Lycoris."""

def __init__(
    self,
    prefix: str,
    orig_module: nn.Module | None,
    dim: int,
    alpha: float,
    decompose_both: bool,
    decompose_factor: int,
    use_tucker: bool,
    weight_decompose: bool,
    dora_on_output: bool,
    full_matrix: bool,
    train_device: torch.device,
    lokr_vec_trick: bool = False,
):
    """Store the LoKr configuration and, if a layer is given, build its weights.

    Args:
        prefix: Forwarded unchanged to the base-class constructor.
        orig_module: Layer to adapt, or ``None`` to defer weight creation.
        dim: LoKr dimension; ``forward`` scales the update by ``alpha / dim``.
        alpha: Scaling numerator, kept as a non-trainable buffer.
        decompose_both: Allow the first factor to be low-rank as well.
        decompose_factor: Passed to ``factorization`` (coerced to ``int``).
        use_tucker: Request the Tucker variant for non-1x1 convolutions.
        weight_decompose: Enable the DoRA-style magnitude parameter.
        dora_on_output: Take DoRA norms per output row instead of per input column.
        full_matrix: Force the second factor to be a full matrix.
        train_device: Device handed to ``get_unquantized_weight``.
        lokr_vec_trick: Use the einsum path in ``forward`` for Linear layers.
    """
    super().__init__(prefix, orig_module)

    # Core hyper-parameters. LoKr calls its rank-like parameter 'dim'.
    self.dim = dim
    self.dropout = Dropout(0)
    self.register_buffer("alpha", torch.tensor(alpha))

    # Behaviour switches copied from the caller's configuration.
    self.decompose_both = decompose_both
    self.decompose_factor = int(decompose_factor)
    self.use_tucker = use_tucker
    self.weight_decompose = weight_decompose
    self.dora_on_output = dora_on_output
    self.train_device = train_device
    self.full_matrix = full_matrix
    self.lokr_vec_trick = lokr_vec_trick

    # Layout flags; initialize_weights() flips these as it creates parameters.
    self.use_w1 = self.use_w2 = self.tucker = False
    self.lokr_dora_scale = None

    # Factorization shapes, filled in by initialize_weights() for Linear layers
    # and read by the einsum path in forward().
    self.in_m = None
    self.in_n = None
    self.out_l = None
    self.out_k = None

    if orig_module is not None:
        self.initialize_weights()
        # Keep alpha on the same device as the wrapped layer's weight.
        self.alpha = self.alpha.to(orig_module.weight.device)
    self.alpha.requires_grad_(False)

def initialize_weights(self):
    """Create and initialise every LoKr parameter for the wrapped layer.

    The in/out sizes of the layer are each split in two by ``factorization``.
    The first factor (W1) pairs the first halves, the second factor (W2) pairs
    the second halves and, for convolutions, also carries the kernel dims.
    Each factor is stored either as a full matrix or as a low-rank pair,
    depending on ``dim`` and the configuration flags. W2 (or its ``_b`` half)
    starts at zero; every other factor gets Kaiming-uniform initialisation.

    Raises:
        NotImplementedError: if the wrapped layer is neither ``nn.Linear``
            nor ``nn.Conv2d``.
    """
    self._initialized = True
    lokr_dim = self.dim
    module = self.orig_module
    device = module.weight.device

    # Resolve layer sizes; a Linear layer simply has no kernel dimensions.
    if isinstance(module, nn.Linear):
        is_linear = True
        in_dim, out_dim = module.in_features, module.out_features
        k_size = ()
    elif isinstance(module, nn.Conv2d):
        is_linear = False
        in_dim, out_dim = module.in_channels, module.out_channels
        k_size = module.kernel_size
    else:
        raise NotImplementedError("Only Linear and Conv2d are supported layers.")

    in_m, in_n = factorization(in_dim, self.decompose_factor)
    out_l, out_k = factorization(out_dim, self.decompose_factor)

    if is_linear:
        # Cached for the einsum path in forward(), which is Linear-only.
        self.in_m, self.in_n = in_m, in_n
        self.out_l, self.out_k = out_l, out_k
        wants_tucker = False
    else:
        # Tucker only makes sense when the kernel is not 1x1.
        self.tucker = self.use_tucker and any(i != 1 for i in k_size)
        wants_tucker = self.tucker

    # --- W1: full matrix, or a low-rank pair when decompose_both allows it ---
    if self.decompose_both and lokr_dim < max(out_l, in_m) / 2:
        self.lokr_w1_a = Parameter(torch.empty(out_l, lokr_dim, device=device))
        self.lokr_w1_b = Parameter(torch.empty(lokr_dim, in_m, device=device))
    else:
        self.use_w1 = True
        self.lokr_w1 = Parameter(torch.empty(out_l, in_m, device=device))

    # --- W2: full matrix, Tucker core + two projections, or a low-rank pair ---
    if self.full_matrix or lokr_dim >= max(out_k, in_n) / 2:
        if not self.full_matrix:
            print(f"LoKr rank {lokr_dim} is too large for dims ({in_dim}, {out_dim}) and factor {self.decompose_factor}, using full matrix mode.")
        self.use_w2 = True
        self.lokr_w2 = Parameter(torch.empty(out_k, in_n, *k_size, device=device))
    elif wants_tucker:
        self.lokr_t2 = Parameter(torch.empty(lokr_dim, lokr_dim, *k_size, device=device))
        self.lokr_w2_a = Parameter(torch.empty(lokr_dim, out_k, device=device))
        self.lokr_w2_b = Parameter(torch.empty(lokr_dim, in_n, device=device))
    else:
        # Kernel dims (if any) are folded into the trailing axis of w2_b.
        self.lokr_w2_a = Parameter(torch.empty(out_k, lokr_dim, device=device))
        self.lokr_w2_b = Parameter(torch.empty(lokr_dim, in_n * math.prod(k_size), device=device))

    # --- Initial values -------------------------------------------------------
    kaiming_slope = math.sqrt(5)
    if self.use_w1:
        nn.init.kaiming_uniform_(self.lokr_w1, a=kaiming_slope)
    else:
        nn.init.kaiming_uniform_(self.lokr_w1_a, a=kaiming_slope)
        nn.init.kaiming_uniform_(self.lokr_w1_b, a=kaiming_slope)

    # Zeroing W2 (or its second half) makes the initial LoKr delta zero.
    if self.use_w2:
        nn.init.constant_(self.lokr_w2, 0)
    else:
        if self.tucker:
            nn.init.kaiming_uniform_(self.lokr_t2, a=kaiming_slope)
        nn.init.kaiming_uniform_(self.lokr_w2_a, a=kaiming_slope)
        nn.init.constant_(self.lokr_w2_b, 0)

    if not self.weight_decompose:
        return

    # --- DoRA magnitude: starts as the norm of the original weight -----------
    if isinstance(module, nn.Linear):
        orig_weight = get_unquantized_weight(module, torch.float, self.train_device)
    else:
        orig_weight = module.weight.detach().float()

    trailing_ones = [1] * (orig_weight.dim() - 1)
    if self.dora_on_output:
        # One norm per output row, shaped (out, 1, ...).
        rows = orig_weight.shape[0]
        flat = orig_weight.reshape(rows, -1)
        dora_scale_val = torch.norm(flat, dim=1, keepdim=True).reshape(rows, *trailing_ones)
    else:
        # One norm per input column, shaped (1, in, ...).
        cols = orig_weight.shape[1]
        flat = orig_weight.transpose(1, 0).reshape(cols, -1)
        dora_scale_val = torch.norm(flat, dim=1, keepdim=True).reshape(cols, *trailing_ones).transpose(0, 1)

    self.lokr_dora_scale = Parameter(
        dora_scale_val.to(device=module.weight.device, dtype=module.weight.dtype)
    )
    del orig_weight

def _get_factors(self):
"""Returns the two kronecker components W1 and W2."""
# If using DoRA (weight_decompose), we want clean weights here so we can
# apply dropout to the input 'x' later.
# If not using DoRA, we apply dropout to the internal factors here.
d = (lambda x: x) if self.weight_decompose else self.dropout

# Handle W1
w1 = d(self.lokr_w1) if self.use_w1 else d(self.lokr_w1_a) @ d(self.lokr_w1_b)

# Handle W2
if self.use_w2:
w2 = d(self.lokr_w2)
elif self.tucker:
w2 = rebuild_tucker(d(self.lokr_t2), d(self.lokr_w2_a), d(self.lokr_w2_b))
else:
w2 = d(self.lokr_w2_a) @ d(self.lokr_w2_b)

return w1, w2

def get_weight(self):
    """Materialise the full LoKr weight delta.

    The two factors from ``_get_factors`` are combined with ``make_kron``
    (the second factor is first cast to the first factor's dtype) and the
    result is viewed as ``self.shape``.
    """
    first, second = self._get_factors()
    combined = make_kron(first, second.to(first.dtype))
    return combined.view(self.shape)

def forward(self, x, *args, **kwargs):
    """Run the wrapped layer with the LoKr update applied.

    Three paths exist:

    * DoRA (``weight_decompose``): rebuild the full weight, renormalise it and
      scale it by ``lokr_dora_scale``, then call ``self.op`` directly.
    * Vec trick (``lokr_vec_trick`` on a Linear layer): apply both factors to
      the input via einsum without building the Kronecker product.
    * Fallback: build the delta weight and add its output to the original
      forward pass.

    Extra positional and keyword arguments are accepted but not used.
    """
    self.check_initialized()

    scale = self.alpha / self.dim

    if not self.weight_decompose:
        if self.lokr_vec_trick and isinstance(self.orig_module, nn.Linear):
            # Apply W1 and W2 sequentially via einsum instead of
            # constructing the full Kronecker product.
            w1, w2 = self._get_factors()

            lead_shape = x.shape[:-1]
            x_grid = x.reshape(-1, self.in_m, self.in_n)

            # delta = x @ (W1 kron W2).T, computed factor by factor.
            delta_output = torch.einsum(
                'bmn, lm, kn -> blk',
                x_grid, w1.to(x.dtype), w2.to(x.dtype)
            )

            # Back to [Batch, ..., Out_Features].
            delta_output = delta_output.reshape(*lead_shape, -1) * scale

            return self.orig_forward(x) + delta_output

        # Fallback for Conv2d layers or when lokr_vec_trick is disabled.
        w = self.get_weight() * scale
        return self.orig_forward(x) + self.op(x, w.to(x.dtype), bias=None, **self.layer_kwargs)

    # DoRA for LoKr.
    # NOTE(review): a reviewer questioned whether DoRA is needed for LoKr at all
    # ("DoRA is already a niche technique, DoRA with LoKr even more so") and
    # flagged issues with it should it be kept; the author replied that it can
    # be removed.
    if isinstance(self.orig_module, nn.Linear):
        orig_weight = get_unquantized_weight(self.orig_module, torch.float, self.train_device)
    else:
        orig_weight = self.orig_module.weight.detach().float()

    delta_w = self.get_weight() * scale
    wp = orig_weight + delta_w
    del orig_weight

    # The norm is taken on a detached copy, so no gradient flows through it.
    eps = torch.finfo(wp.dtype).eps
    trailing_ones = [1] * (wp.dim() - 1)
    frozen = wp.detach()
    if self.dora_on_output:
        rows = wp.shape[0]
        norm = frozen.reshape(rows, -1).norm(dim=1).reshape(rows, *trailing_ones) + eps
    else:
        cols = wp.shape[1]
        per_col = frozen.transpose(0, 1).reshape(cols, -1).norm(dim=1, keepdim=True)
        norm = per_col.reshape(cols, *trailing_ones).transpose(0, 1) + eps

    wp = self.lokr_dora_scale * (wp / norm)

    # Apply dropout to the input 'x' (DoRA style).
    return self.op(self.dropout(x), wp.to(x.dtype), self.orig_module.bias, **self.layer_kwargs)

def apply_to_module(self):
    """Placeholder: not implemented for LoKr yet, so calling it does nothing."""
    # TODO
    pass

def extract_from_module(self, base_module: nn.Module):
    """Placeholder: not implemented for LoKr yet; ``base_module`` is ignored."""
    # TODO
    pass


class LoRAModule(PeftBase):
lora_down: nn.Module | None
lora_up: nn.Module | None
Expand Down Expand Up @@ -571,6 +789,7 @@ def forward(self, x, *args, **kwargs):
DummyDoRAModule = DoRAModule.make_dummy()
DummyLoHaModule = LoHaModule.make_dummy()
DummyOFTModule = OFTModule.make_dummy()
DummyLoKrModule = LoKrModule.make_dummy()


class LoRAModuleWrapper:
Expand All @@ -580,6 +799,7 @@ class LoRAModuleWrapper:
module_filters: list[ModuleFilter]

lora_modules: dict[str, PeftBase]
lokr_dim: int

def __init__(
self,
Expand All @@ -593,15 +813,15 @@ def __init__(
self.peft_type = config.peft_type
self.rank = config.lora_rank
self.alpha = config.lora_alpha
self.lokr_dim = config.lokr_dim

self.module_filters = [
ModuleFilter(pattern, use_regex=config.layer_filter_regex)
for pattern in (module_filter or [])
]

weight_decompose = config.lora_decompose
if self.peft_type == PeftType.LORA:
if weight_decompose:
if config.lora_decompose:
self.klass = DoRAModule
self.dummy_klass = DummyDoRAModule
self.additional_args = [self.rank, self.alpha]
Expand Down Expand Up @@ -632,7 +852,20 @@ def __init__(
self.additional_kwargs = {
'dropout_probability': config.dropout_probability,
}

elif self.peft_type == PeftType.LOKR:
self.klass = LoKrModule
self.dummy_klass = DummyLoKrModule
self.additional_args = [self.lokr_dim, self.alpha]
self.additional_kwargs = {
'decompose_both': config.lokr_decompose_both,
'decompose_factor': config.lokr_decompose_factor,
'use_tucker': config.lokr_use_tucker,
'weight_decompose': config.lokr_weight_decompose,
'dora_on_output': config.lokr_dora_on_output,
'full_matrix': config.lokr_full_matrix,
'train_device': torch.device(config.train_device),
'lokr_vec_trick': config.lokr_vec_trick,
}
self.lora_modules = self.__create_modules(orig_module, config)

def __create_modules(self, orig_module: nn.Module | None, config: TrainConfig) -> dict[str, PeftBase]:
Expand Down Expand Up @@ -714,9 +947,21 @@ def _check_rank_matches(self, state_dict: dict[str, Tensor]):
if self.peft_type == PeftType.OFT_2:
return

if rank_key := next((k for k in state_dict if k.endswith((".lora_down.weight", ".hada_w1_a"))), None):
if (checkpoint_rank := state_dict[rank_key].shape[0]) != self.rank:
raise ValueError(f"Rank mismatch: checkpoint={checkpoint_rank}, config={self.rank}, please correct in the UI.")
rank_keys = {
PeftType.LORA: ".lora_down.weight",
PeftType.LOHA: ".hada_w1_a",
PeftType.LOKR: ".lokr_w1_a",
}
key_suffix = rank_keys.get(self.peft_type)
if not key_suffix:
return

if rank_key := next((k for k in state_dict if k.endswith(key_suffix)), None):
checkpoint_rank = state_dict[rank_key].shape[1] if self.peft_type == PeftType.LOKR else state_dict[rank_key].shape[0]
config_rank = self.lokr_dim if self.peft_type == PeftType.LOKR else self.rank

if checkpoint_rank != config_rank:
raise ValueError(f"Rank/Dim mismatch: checkpoint={checkpoint_rank}, config={config_rank}, please correct in the UI.")

def load_state_dict(self, state_dict: dict[str, Tensor], strict: bool = True):
"""
Expand Down
65 changes: 65 additions & 0 deletions modules/ui/LoraTab.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ def refresh_ui(self):
("LoRA", PeftType.LORA),
("LoHa", PeftType.LOHA),
("OFT v2", PeftType.OFT_2),
("LoKr", PeftType.LOKR),
], self.ui_state, "peft_type", command=self.setup_lora)

def setup_lora(self, peft_type: PeftType):
if peft_type == PeftType.LOHA:
name = "LoHa"
elif peft_type == PeftType.OFT_2:
name = "OFT v2"
elif peft_type == PeftType.LOKR:
name = "LoKr"
else:
name = "LoRA"

Expand Down Expand Up @@ -152,3 +155,65 @@ def setup_lora(self, peft_type: PeftType):
components.label(master, 4, 0, "Bundle Embeddings",
tooltip=f"Bundles any additional embeddings into the {name} output file, rather than as separate files")
components.switch(master, 4, 1, self.ui_state, "bundle_additional_embeddings")

# LoKr
elif peft_type == PeftType.LOKR:
# LoKr Main Settings
components.label(master, 1, 0, f"{name} dimension",
tooltip="The dimension parameter used for the secondary decomposition. Analogous to rank in LoRA.")
components.entry(master, 1, 1, self.ui_state, "lokr_dim")

components.label(master, 2, 0, "Decomposition Factor",
tooltip="Factor for Kronecker product decomposition. -1 for auto, which is recommended. Changing this drastically affects parameter count.")
components.entry(master, 2, 1, self.ui_state, "lokr_decompose_factor")

# alpha
components.label(master, 3, 0, f"{name} alpha",
tooltip=f"The alpha parameter used when creating a new {name}")
components.entry(master, 3, 1, self.ui_state, "lora_alpha")

# Dropout Percentage
components.label(master, 4, 0, "Dropout Probability",
tooltip="Dropout probability. This percentage of model nodes will be randomly ignored at each training step. Helps with overfitting. 0 disables, 1 maximum.")
components.entry(master, 4, 1, self.ui_state, "dropout_probability")

# LoKr weight dtype
components.label(master, 5, 0, f"{name} Weight Data Type",
tooltip=f"The {name} weight data type used for training. This can reduce memory consumption, but reduces precision")
components.options_kv(master, 5, 1, [
("float32", DataType.FLOAT_32),
("bfloat16", DataType.BFLOAT_16),
], self.ui_state, "lora_weight_dtype")

# LoKr Vectorization trick
components.label(master, 6, 0, "Kronecker-Vec Trick",
tooltip="Uses an accelerated path that bypasses the materialization of the full Kronecker product. This delivers a massive speedup to the LoKr without sacrificing precision. Highly recommended.")
components.switch(master, 6, 1, self.ui_state, "lokr_vec_trick")

#LoKr Decomposition Settings
components.label(master, 1, 3, "Decompose Both Matrices",
tooltip="Perform rank decomposition on both Kronecker product matrices (W1 and W2). Only effective for very small dimensions.")
components.switch(master, 1, 4, self.ui_state, "lokr_decompose_both")

components.label(master, 2, 3, "Use Tucker Decomposition (Conv)",
tooltip="Use Tucker decomposition for convolutional layers. Can be more efficient for some architectures.")
components.switch(master, 2, 4, self.ui_state, "lokr_use_tucker")

components.label(master, 3, 3, "Force Full Matrix (W2)",
tooltip="Forces the second Kronecker matrix (W2) to be a full matrix, ignoring the dimension setting. For expert use.")
components.switch(master, 3, 4, self.ui_state, "lokr_full_matrix")

# LoKr DoRA Settings
components.label(master, 4, 3, "Decompose Weights (DoRA)",
tooltip="Apply weight decomposition (DoRA) on top of the LoKr update.")
components.switch(master, 4, 4, self.ui_state, "lokr_weight_decompose")

components.label(master, 5, 3, "Apply DoRA on Output Axis",
tooltip="Apply the DoRA weight decomposition on the output axis instead of the input axis.")
components.switch(master, 5, 4, self.ui_state, "lokr_dora_on_output")


# Additional embeddings
components.label(master, 6, 3, "Bundle Embeddings",
tooltip=f"Bundles any additional embeddings into the {name} output file, rather than as separate files")
components.switch(master, 6, 4, self.ui_state, "bundle_additional_embeddings")
Loading