Fix C3ALinear.forward chaining active adapters through x instead of summing#3164
Conversation
In `C3ALinear.forward`, when multiple adapters are active the inner
loop rebinds `x` to the output of the current adapter:
```python
for active_adapter in self.active_adapters:
    ...
    c3a_kernel = self.c3a_kernel[active_adapter].to(torch.float32)
    x = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
    result += x.to(result.dtype)
```
On the next iteration, `x` is no longer the original input but the
previous adapter's output, so each subsequent adapter's
`BlockCircularConvolution` is applied to that output instead of the
shared input. The adapters end up stacked as
```
result = W0·x + A1(x) + A2(A1(x)) + A3(A2(A1(x))) + ...
```
rather than
```
result = W0·x + A1(x) + A2(x) + A3(x) + ...
```
This does not match `merge`, which simply adds each adapter's delta
weight linearly:
```python
base_layer.weight.data = base_layer.weight.data + delta_weight
```
so behaviour diverges between a merged and an unmerged multi-adapter
model, and the unmerged forward is demonstrably wrong on its own
terms (later adapters are applied to earlier adapters' low-rank
outputs).
Fix: bind the per-adapter result to a local `delta` instead of
reassigning `x`, so every adapter sees the original input `x`:
```python
delta = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
result += delta.to(result.dtype)
```
Behaviour is unchanged when at most one C3A adapter is active, which
matches the existing test coverage.
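The divergence is easy to see with a toy sketch (not PEFT code): the adapters below are stand-in linear maps, not `BlockCircularConvolution`, but the loop structure mirrors the buggy and fixed versions of `forward`:

```python
# Hypothetical stand-in adapters; each should be applied to the same input.
def a1(x):
    return 2 * x

def a2(x):
    return 3 * x

def buggy_forward(x, base=10):
    result = base * x
    for adapter in (a1, a2):
        x = adapter(x)       # rebinds x, as in the buggy loop
        result += x
    return result

def fixed_forward(x, base=10):
    result = base * x
    for adapter in (a1, a2):
        delta = adapter(x)   # x stays bound to the original input
        result += delta
    return result

print(buggy_forward(1))  # 10 + 2 + 6 = 18 (a2 applied to a1's output)
print(fixed_forward(1))  # 10 + 2 + 3 = 15 (a2 applied to x)
```

With a single adapter the two versions agree, which is why the bug only surfaces with multiple active adapters.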
```diff
     c3a_kernel = self.c3a_kernel[active_adapter].to(torch.float32)
-    x = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
-    result += x.to(result.dtype)
+    delta = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
+    result += delta.to(result.dtype)
```
BenjaminBossan left a comment
Thanks for identifying this issue. I agree that it's a bug. Two comments:
First, could you please add unit tests for this? We already have tests for multiple adapters and checking if merging produces the same result. However, C3A is missing there. Please add it to the test matrix here:
peft/tests/test_custom_models.py
Line 1117 in 5b6e5e9
Regarding the proposed solution: I think it works but what we could do instead is to sum all the active c3a_kernels and then send them through BlockCircularConvolution.apply instead of the other way round. This should get the same result and be more efficient. WDYT?
Also pinging @Phoveran
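The reviewer's suggestion relies on circular convolution being linear in the kernel, so convolving once with the summed kernels equals summing the per-kernel convolutions. A minimal numpy sketch of that identity (the naive `circ_conv` below is illustrative, not PEFT's `BlockCircularConvolution`):

```python
import numpy as np

def circ_conv(x, k):
    # Naive circular convolution: out[i] = sum_j x[(i - j) mod n] * k[j]
    n = len(k)
    return np.array([sum(x[(i - j) % n] * k[j] for j in range(n)) for i in range(n)])

rng = np.random.default_rng(0)
x = rng.standard_normal(8)
k1 = rng.standard_normal(8)
k2 = rng.standard_normal(8)

sum_kernels_first = circ_conv(x, k1 + k2)                # one convolution
convolve_each = circ_conv(x, k1) + circ_conv(x, k2)      # two convolutions
print(np.allclose(sum_kernels_first, convolve_each))     # True
```

Summing the kernels first replaces N convolutions with one, at the cost of a small kernel addition.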
Bug
When more than one C3A adapter is active simultaneously, `C3ALinear.forward` feeds the previous adapter's output into the next adapter instead of feeding every adapter the original input. As a result, the unmerged forward pass produces a different value than the merged model for any multi-adapter configuration, and later adapters are spuriously applied on top of earlier adapters' low-rank outputs.
Root cause
In the active branch of `forward` (in `src/peft/tuners/c3a/layer.py`):
```python
result = self.base_layer(x, *args, **kwargs)
x = x.to(torch.float32)
for active_adapter in self.active_adapters:
    if active_adapter not in self.c3a_kernel.keys():
        continue
    c3a_kernel = self.c3a_kernel[active_adapter].to(torch.float32)
    x = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
    result += x.to(result.dtype)
```
The loop rebinds `x` to the current adapter's `BlockCircularConvolution(x, kernel_A) / x.size(-1)`. On the next iteration, that already-transformed `x` is passed into the next `BlockCircularConvolution`. For `N` active adapters the output becomes
```
W0·x + A1(x) + A2(A1(x)) + A3(A2(A1(x))) + ...
```
instead of the linear sum the merge path implements:
```python
merge():
base_layer.weight.data = base_layer.weight.data + delta_weight # one adapter at a time
```
which yields `W0·x + A1(x) + A2(x) + … + An(x)`. The unmerged and merged models therefore disagree whenever `len(active_adapters) > 1`, and the unmerged forward is incorrect on its own terms (each adapter should see the shared input, not the previous adapter's output).
Fix
Bind the per-adapter contribution to a local `delta` instead of reassigning `x`, so every adapter sees the original `x`:
```diff
     c3a_kernel = self.c3a_kernel[active_adapter].to(torch.float32)
-    x = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
-    result += x.to(result.dtype)
+    delta = BlockCircularConvolution.apply(x, c3a_kernel) / x.size(-1)
+    result += delta.to(result.dtype)
```
When at most one C3A adapter is active (the path covered by existing tests) behaviour is identical; for multiple active adapters it now matches the merged model.
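The kind of merged-vs-unmerged parity the fix restores can be sketched with the same toy linear adapters (illustrative only, not the PEFT test suite): merging folds every adapter's delta weight into the base weight, and the fixed unmerged forward must agree with it.

```python
base_w = 10.0
adapter_ws = [2.0, 3.0]  # hypothetical per-adapter delta weights

def unmerged_forward(x):
    result = base_w * x
    for w in adapter_ws:
        delta = w * x        # every adapter sees the original x (the fix)
        result += delta
    return result

def merged_forward(x):
    merged_w = base_w + sum(adapter_ws)  # merge(): add each delta weight
    return merged_w * x

assert unmerged_forward(4.0) == merged_forward(4.0)  # 60.0 == 60.0
```

With the buggy chaining, `unmerged_forward` would instead compute `10·x + 2·x + 3·(2·x)`, and this assertion would fail for any nonzero input.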