Skip to content

Commit cfcd717

Browse files
committed
perf: minor optimizations
1 parent bad746b commit cfcd717

1 file changed

Lines changed: 12 additions & 7 deletions

File tree

src/pruna/algorithms/token_merging.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import math
17+
from collections import deque
1718
from typing import Any, Callable, Optional, Tuple
1819

1920
import torch
@@ -104,7 +105,7 @@ def _bipartite_soft_matching(
104105
return _do_nothing, _do_nothing
105106

106107
with torch.no_grad():
107-
tokens = tokens / tokens.norm(dim=-1, keepdim=True)
108+
tokens = torch.nn.functional.normalize(tokens, dim=-1)
108109
a, b = tokens[..., ::2, :], tokens[..., 1::2, :]
109110
scores = a @ b.transpose(-1, -2)
110111

@@ -157,6 +158,10 @@ def _merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor | None = No
157158
"""
158159
Merge via weighted average based on token size.
159160
161+
Concatenates ``x * size`` and ``size`` along the channel dimension and
162+
performs a *single* merge call instead of two, halving kernel-launch
163+
overhead.
164+
160165
Parameters
161166
----------
162167
merge : Callable
@@ -174,9 +179,10 @@ def _merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor | None = No
174179
if size is None:
175180
size = torch.ones_like(x[..., 0, None])
176181

177-
x = merge(x * size, mode="sum")
178-
size = merge(size, mode="sum")
179-
x = x / size
182+
# Single merge pass: cat weighted tokens and sizes, split after merge.
183+
combined = merge(torch.cat([x * size, size], dim=-1), mode="sum")
184+
size = combined[..., -1:]
185+
x = combined[..., :-1] / size
180186
return x, size
181187

182188

@@ -350,7 +356,7 @@ def forward(
350356
hidden_states = attention_output + hidden_states
351357

352358
# --- token merging ---
353-
r = self._tome_info["r"].pop(0)
359+
r = self._tome_info["r"].popleft()
354360
if r > 0:
355361
metric = self._tome_info["metric"]
356362
merge, _ = _bipartite_soft_matching(
@@ -433,8 +439,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
433439
Any
434440
The output of the wrapped model's forward pass.
435441
"""
436-
# Make a copy of the list to avoid modifying the original
437-
self._tome_info["r"] = list(self.parsed_r)
442+
self._tome_info["r"] = deque(self.parsed_r)
438443
self._tome_info["size"] = None
439444
self._tome_info["source"] = None
440445
self._tome_info["metric"] = None

0 commit comments

Comments (0)