Skip to content

Commit 6b0ad8f

Browse files
kikoncuo, claude, and Blaizzy
authored
fix: use alpha/rank scaling in LoRaLayer (standard LoRA convention) (#846)
* fix: use alpha/rank scaling in LoRaLayer (standard LoRA convention)

  LoRaLayer used raw `alpha` as the scaling factor instead of `alpha / rank`. With the default alpha=16, rank=8, this made the LoRA contribution 8x larger than in PEFT, the original LoRA paper, and mlx-lm.

  Before: scale = alpha = 16.0
  After: scale = alpha / rank = 2.0

  Also fixes replace_lora_with_linear to use the same corrected scale.

  Added tests verifying:
  - scale = alpha / rank
  - Forward pass produces (alpha/rank) * (x @ A @ B)
  - Default settings give 2x scaling, not 16x

  Fixes #845

* test: add B=0 initialization test for LoRaLayer

  Verifies that when B is zeros (default init), the LoRA layer output equals the base linear layer output exactly (no LoRA contribution).

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: move LoRA scaling tests into test_trainer.py

  Move TestLoRaScaling class from test_trainer_utils.py into test_trainer.py as suggested in review, and revert test_trainer_utils.py to its original state.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
1 parent e2e9e67 commit 6b0ad8f

2 files changed

Lines changed: 68 additions & 3 deletions

File tree

mlx_vlm/tests/test_trainer.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import mlx.nn as nn
66

77
from mlx_vlm.trainer.datasets import VisionDataset
8+
from mlx_vlm.trainer.lora import LoRaLayer
89
from mlx_vlm.trainer.sft_trainer import TrainingArgs, train
910

1011

@@ -142,6 +143,70 @@ def test_train_smoke(self, mock_save_safetensors, mock_iterate_batches):
142143
mock_save_safetensors.assert_called()
143144

144145

146+
class TestLoRaScaling(unittest.TestCase):
147+
"""Verify LoRaLayer uses alpha/rank scaling (standard LoRA convention)."""
148+
149+
def test_scale_is_alpha_over_rank(self):
150+
linear = nn.Linear(4, 4)
151+
lora = LoRaLayer(linear, rank=8, alpha=16.0)
152+
self.assertAlmostEqual(lora.scale, 2.0) # 16 / 8 = 2.0
153+
154+
def test_scale_with_rank_equals_alpha(self):
155+
linear = nn.Linear(4, 4)
156+
lora = LoRaLayer(linear, rank=4, alpha=4.0)
157+
self.assertAlmostEqual(lora.scale, 1.0) # 4 / 4 = 1.0
158+
159+
def test_forward_scaling_matches_peft(self):
160+
"""LoRA contribution should equal (alpha/rank) * (x @ A @ B)."""
161+
linear = nn.Linear(4, 4)
162+
lora = LoRaLayer(linear, rank=8, alpha=16.0, dropout=0.0)
163+
164+
# Set deterministic weights
165+
lora.A = mx.ones((4, 8))
166+
lora.B = mx.ones((8, 4))
167+
x = mx.ones((1, 4))
168+
169+
base_output = linear(x)
170+
actual_output = lora(x)
171+
lora_contribution = actual_output - base_output
172+
173+
# Expected: (alpha / rank) * (x @ A @ B) = 2.0 * (ones(1,4) @ ones(4,8) @ ones(8,4))
174+
# x @ A = 4 * ones(1,8), then @ B = 32 * ones(1,4), then * 2.0 = 64
175+
expected_per_element = 2.0 * 4.0 * 8.0 # 64.0
176+
self.assertAlmostEqual(
177+
lora_contribution[0, 0].item(), expected_per_element, places=1
178+
)
179+
180+
def test_b_zero_init_gives_no_lora_contribution(self):
181+
"""When B is zeros (default init), output should equal base linear."""
182+
linear = nn.Linear(4, 4)
183+
lora = LoRaLayer(linear, rank=8, alpha=16.0, dropout=0.0)
184+
# B is already zeros from __init__, don't override it
185+
x = mx.ones((1, 4))
186+
base_output = linear(x)
187+
lora_output = lora(x)
188+
self.assertTrue(mx.allclose(base_output, lora_output).item())
189+
190+
def test_default_alpha_rank_gives_2x(self):
191+
"""Default alpha=16, rank=8 should give 2x scaling, not 16x."""
192+
linear = nn.Linear(8, 8)
193+
lora = LoRaLayer(linear, rank=8, alpha=16.0, dropout=0.0)
194+
195+
lora.A = mx.ones((8, 8))
196+
lora.B = mx.ones((8, 8))
197+
x = mx.ones((1, 8))
198+
199+
base = linear(x)
200+
actual = lora(x)
201+
contribution = (actual - base)[0, 0].item()
202+
203+
raw_delta = (x @ lora.A @ lora.B)[0, 0].item() # 64.0
204+
205+
# Should be 2x the raw delta, not 16x
206+
self.assertAlmostEqual(contribution, 2.0 * raw_delta, places=1)
207+
self.assertNotAlmostEqual(contribution, 16.0 * raw_delta, places=1)
208+
209+
145210
if __name__ == "__main__":
146211
unittest.main()
147212

mlx_vlm/trainer/lora.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,19 @@ def __init__(
3131
shape=(input_dims, rank),
3232
)
3333
self.B = mx.zeros((rank, output_dims))
34-
self.alpha = alpha
34+
self.scale = alpha / rank
3535

3636
def __call__(self, x):
3737
y = self.original_layer(x)
3838
lora_update = (self.dropout(x) @ self.A) @ self.B
39-
return y + (self.alpha * lora_update).astype(x.dtype)
39+
return y + (self.scale * lora_update).astype(x.dtype)
4040

4141

4242
def replace_lora_with_linear(model):
4343
for i, layer in enumerate(model.layers):
4444
if isinstance(layer, LoRaLayer):
4545
# Compute the final merged weight
46-
lora_update = layer.alpha * (layer.A @ layer.B)
46+
lora_update = layer.scale * (layer.A @ layer.B)
4747
updated_weight = layer.original_layer.weight + lora_update
4848
use_bias = layer.original_layer.bias is not None
4949

0 commit comments

Comments
 (0)