12 changes: 10 additions & 2 deletions src/peft/tuners/adalora/__init__.py
@@ -17,11 +17,19 @@

from .config import AdaLoraConfig
from .gptq import SVDQuantLinear
from .layer import AdaLoraLayer, RankAllocator, SVDLinear
from .layer import AdaLoraLayer, RankAllocator, SVDConv2d, SVDLinear
from .model import AdaLoraModel


__all__ = ["AdaLoraConfig", "AdaLoraLayer", "AdaLoraModel", "RankAllocator", "SVDLinear", "SVDQuantLinear"]
__all__ = [
"AdaLoraConfig",
"AdaLoraLayer",
"AdaLoraModel",
"RankAllocator",
"SVDConv2d",
"SVDLinear",
"SVDQuantLinear",
]


register_peft_method(
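For downstream users, the new export resolves from the subpackage exactly like the existing ones. A quick sanity check, assuming a PEFT build that includes this change:

from peft.tuners.adalora import SVDConv2d  # resolves after this change
import peft.tuners.adalora as adalora
assert "SVDConv2d" in adalora.__all__  # listed next to SVDLinear and SVDQuantLinear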
172 changes: 163 additions & 9 deletions src/peft/tuners/adalora/layer.py
@@ -23,6 +23,7 @@
from peft.tuners.lora import LoraLayer
from peft.tuners.tuners_utils import check_adapters_to_merge
from peft.utils import transpose
from peft.utils.integrations import _skip_init_on_device

from .config import AdaLoraConfig

@@ -70,15 +71,22 @@ def update_layer(self, adapter_name: str, r: int, lora_alpha: int, config: AdaLoraConfig
self.lora_E[adapter_name] = nn.Parameter(torch.randn(r, 1))
# Left singular vectors
self.lora_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, r))
# The current rank
self.ranknum[adapter_name] = nn.Parameter(torch.randn(1), requires_grad=False)
self.ranknum[adapter_name].data.fill_(float(r))
self.ranknum[adapter_name].requires_grad = False
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
if init_lora_weights:
self.reset_lora_parameters(adapter_name)

self._move_adapter_to_device_of_base_layer(adapter_name)
# The current rank. ranknum is deterministic (always set to float(r)) and is not saved
# in the adapter state_dict, so under low_cpu_mem_usage=True it can't be restored from a
# checkpoint and must stay off the meta device. Two places need to skip the meta dispatch:
# 1) creation, otherwise the Parameter is registered on meta from the start;
# 2) the subsequent _move_adapter_to_device_of_base_layer call, because moving the
# ParameterDict entry re-registers it via Module.__setattr__ → register_parameter,
# which inside init_empty_weights would send the tensor back to meta.
with _skip_init_on_device():
self.ranknum[adapter_name] = nn.Parameter(torch.randn(1), requires_grad=False)
self.ranknum[adapter_name].data.fill_(float(r))
self.ranknum[adapter_name].requires_grad = False
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
if init_lora_weights:
self.reset_lora_parameters(adapter_name)

self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)

def reset_lora_parameters(self, adapter_name):
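The meta-device hazard the new comment describes is easy to reproduce with accelerate's init_empty_weights, which is what low_cpu_mem_usage=True relies on. A minimal sketch; the ranknum name just mirrors the diff, and any parameter registered on a module inside the context behaves the same:

import torch
from torch import nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.Linear(4, 4)
    # Assignment goes through Module.__setattr__ -> register_parameter, which the
    # context patches to dispatch the tensor to the meta device, so any values
    # filled in afterwards are lost.
    layer.ranknum = nn.Parameter(torch.randn(1), requires_grad=False)

print(layer.weight.device, layer.ranknum.device)  # both report meta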
@@ -193,6 +201,152 @@ def __repr__(self) -> str:
return "adalora." + rep


class SVDConv2d(nn.Module, AdaLoraLayer):
# SVD-based adaptation for a Conv2d layer
def __init__(
self,
base_layer: nn.Module,
adapter_name: str,
config: AdaLoraConfig,
r: int = 0,
lora_alpha: int = 1,
**kwargs,
) -> None:
super().__init__()
AdaLoraLayer.__init__(self, base_layer)
self.get_base_layer().weight.requires_grad = False
self._active_adapter = adapter_name
self.update_layer(adapter_name, r, lora_alpha, config=config)

def update_layer(self, adapter_name: str, r: int, lora_alpha: int, config: AdaLoraConfig, **kwargs) -> None:
lora_dropout = config.lora_dropout
init_lora_weights = config.init_lora_weights
inference_mode = config.inference_mode
if r < 0:
raise ValueError(f"`r` should be a positive integer or 0, but the value passed is {r}")

self.r[adapter_name] = r
self.lora_alpha[adapter_name] = lora_alpha
if lora_dropout > 0.0:
lora_dropout_layer = nn.Dropout(p=lora_dropout)
else:
lora_dropout_layer = nn.Identity()
self.lora_dropout[adapter_name] = lora_dropout_layer

base = self.get_base_layer()
k = base.kernel_size[0] * base.kernel_size[1]
# Right singular vectors: (r, in_channels * kH * kW)
self.lora_A[adapter_name] = nn.Parameter(torch.randn(r, self.in_features * k))
# Singular values: (r, 1)
self.lora_E[adapter_name] = nn.Parameter(torch.randn(r, 1))
# Left singular vectors: (out_channels, r)
self.lora_B[adapter_name] = nn.Parameter(torch.randn(self.out_features, r))
# Current rank; see the note in AdaLoraLayer.update_layer about _skip_init_on_device.
with _skip_init_on_device():
self.ranknum[adapter_name] = nn.Parameter(torch.randn(1), requires_grad=False)
self.ranknum[adapter_name].data.fill_(float(r))
self.ranknum[adapter_name].requires_grad = False
self.scaling[adapter_name] = lora_alpha if lora_alpha > 0 else float(r)
if init_lora_weights:
self.reset_lora_parameters(adapter_name)

self._move_adapter_to_device_of_base_layer(adapter_name)
self.set_adapter(self.active_adapters, inference_mode=inference_mode)

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Merge the active adapter weights into the base weights

Args:
safe_merge (`bool`, *optional*):
If True, the merge operation will be performed on a copy of the original weights and checked for NaNs
before the result is written back. This is useful if you want to check whether the merge operation will
produce NaNs. Defaults to `False`.
adapter_names (`List[str]`, *optional*):
The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
to `None`.
"""
adapter_names = check_adapters_to_merge(self, adapter_names)
if not adapter_names:
return

for active_adapter in adapter_names:
base_layer = self.get_base_layer()
if active_adapter in self.lora_A.keys():
if safe_merge:
orig_weights = base_layer.weight.data.clone()
orig_weights += self.get_delta_weight(active_adapter)

if not torch.isfinite(orig_weights).all():
raise ValueError(
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
)

base_layer.weight.data = orig_weights
else:
base_layer.weight.data += self.get_delta_weight(active_adapter)
self.merged_adapters.append(active_adapter)

def unmerge(self) -> None:
"""
This method unmerges all merged adapter layers from the base weights.
"""
if not self.merged:
warnings.warn("Already unmerged. Nothing to do.")
return
while len(self.merged_adapters) > 0:
active_adapter = self.merged_adapters.pop()
if active_adapter in self.lora_A.keys():
self.get_base_layer().weight.data -= self.get_delta_weight(active_adapter)

def get_delta_weight(self, adapter) -> torch.Tensor:
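# Precedence note: `@` and `*` bind equally and associate left to right, so this
# computes ((lora_B @ (lora_A * lora_E)) * scaling) / (ranknum + 1e-5), the same
# formula SVDLinear uses, just reshaped onto the conv weight at the end.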
delta = (
self.lora_B[adapter]
@ (self.lora_A[adapter] * self.lora_E[adapter])
* self.scaling[adapter]
/ (self.ranknum[adapter] + 1e-5)
)
return delta.reshape(self.get_base_layer().weight.shape)

def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
if self.disable_adapters:
if self.merged:
self.unmerge()
result = self.base_layer(x, *args, **kwargs)
elif self.merged:
result = self.base_layer(x, *args, **kwargs)
else:
result = self.base_layer(x, *args, **kwargs)
for active_adapter in self.active_adapters:
if active_adapter not in self.lora_A.keys():
continue
lora_A = self.lora_A[active_adapter]
lora_B = self.lora_B[active_adapter]
lora_E = self.lora_E[active_adapter]
dropout = self.lora_dropout[active_adapter]
scaling = self.scaling[active_adapter]
ranknum = self.ranknum[active_adapter] + 1e-5

x = self._cast_input_dtype(x, lora_A.dtype)
base = self.get_base_layer()
delta_w = (lora_B @ (lora_A * lora_E) * scaling / ranknum).reshape(base.weight.shape)
# Use in-place add so result keeps its original dtype (matches SVDLinear);
# otherwise result would be upcast to lora_A.dtype and break subsequent fp16/bf16 ops.
result += nn.functional.conv2d(
dropout(x),
delta_w,
stride=base.stride,
padding=base.padding,
dilation=base.dilation,
groups=base.groups,
)
return result

def __repr__(self) -> str:
rep = super().__repr__()
return "adalora." + rep


class RankAllocator:
"""
The RankAllocator for AdaLoraModel. Paper: https://openreview.net/pdf?id=lq62uWRJjiY
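To make the SVDConv2d factor shapes concrete: lora_B @ (lora_A * lora_E) yields an (out_channels, in_channels * kH * kW) matrix that reshapes exactly onto a Conv2d weight, which is what get_delta_weight relies on. A standalone sketch with made-up sizes (groups=1, matching the reshape above):

import torch

r, in_ch, out_ch, kH, kW = 4, 3, 8, 3, 3   # hypothetical sizes
lora_A = torch.randn(r, in_ch * kH * kW)   # right singular vectors
lora_E = torch.randn(r, 1)                 # singular values, broadcast across columns of lora_A
lora_B = torch.randn(out_ch, r)            # left singular vectors

delta = lora_B @ (lora_A * lora_E)         # shape: (out_ch, in_ch * kH * kW)
weight = torch.nn.Conv2d(in_ch, out_ch, (kH, kW)).weight
assert delta.reshape(weight.shape).shape == weight.shape  # (8, 3, 3, 3)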
11 changes: 7 additions & 4 deletions src/peft/tuners/adalora/model.py
@@ -30,7 +30,7 @@
from peft.utils.integrations import gather_params_ctx

from .gptq import SVDQuantLinear
from .layer import AdaLoraLayer, RankAllocator, SVDLinear
from .layer import AdaLoraLayer, RankAllocator, SVDConv2d, SVDLinear


class AdaLoraModel(LoraModel):
@@ -205,12 +205,15 @@ def _create_new_module(lora_config, adapter_name, target, device_map=None, **kwargs):
"Setting fan_in_fan_out to True."
)
lora_config.fan_in_fan_out = True
else:
elif not isinstance(target_base_layer, torch.nn.Conv2d):
raise ValueError(
f"Target module {target} is not supported. "
f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
f"Currently, only `torch.nn.Linear`, `Conv1D`, and `torch.nn.Conv2d` are supported."
)
new_module = SVDLinear(target, adapter_name, config=lora_config, **kwargs)
if isinstance(target_base_layer, torch.nn.Conv2d):
new_module = SVDConv2d(target, adapter_name, config=lora_config, **kwargs)
else:
new_module = SVDLinear(target, adapter_name, config=lora_config, **kwargs)

return new_module

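End to end, the dispatch change means Conv2d targets no longer raise. A usage sketch against a toy model; the module names and AdaLoRA hyperparameters below are illustrative, not prescribed by this PR:

import torch
from torch import nn
from peft import AdaLoraConfig, get_peft_model

model = nn.Sequential()
model.add_module("conv", nn.Conv2d(3, 8, kernel_size=3, padding=1))
model.add_module("flatten", nn.Flatten())
model.add_module("head", nn.Linear(8 * 8 * 8, 2))

# "conv" now resolves to SVDConv2d instead of raising ValueError
config = AdaLoraConfig(target_modules=["conv"], init_r=12, target_r=4, total_step=100)
peft_model = get_peft_model(model, config)
print(peft_model(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 2])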