1414from __future__ import annotations
1515
1616import math
17+ from collections import deque
1718from typing import Any , Callable , Optional , Tuple
1819
1920import torch
@@ -104,7 +105,7 @@ def _bipartite_soft_matching(
104105 return _do_nothing , _do_nothing
105106
106107 with torch .no_grad ():
107- tokens = tokens / tokens . norm ( dim = - 1 , keepdim = True )
108+ tokens = torch . nn . functional . normalize ( tokens , dim = - 1 )
108109 a , b = tokens [..., ::2 , :], tokens [..., 1 ::2 , :]
109110 scores = a @ b .transpose (- 1 , - 2 )
110111
@@ -157,6 +158,10 @@ def _merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor | None = No
157158 """
158159 Merge via weighted average based on token size.
159160
161+ Concatenates ``x * size`` and ``size`` along the last (channel) dimension and
162+ performs a *single* ``merge`` call instead of two, then splits the result
163+ back into the weighted token sum and the merged size.
164+
160165 Parameters
161166 ----------
162167 merge : Callable
@@ -174,9 +179,10 @@ def _merge_wavg(merge: Callable, x: torch.Tensor, size: torch.Tensor | None = No
174179 if size is None :
175180 size = torch .ones_like (x [..., 0 , None ])
176181
177- x = merge (x * size , mode = "sum" )
178- size = merge (size , mode = "sum" )
179- x = x / size
182+ # Single merge pass: cat weighted tokens and sizes, split after merge.
183+ combined = merge (torch .cat ([x * size , size ], dim = - 1 ), mode = "sum" )
184+ size = combined [..., - 1 :]
185+ x = combined [..., :- 1 ] / size
180186 return x , size
181187
182188
@@ -350,7 +356,7 @@ def forward(
350356 hidden_states = attention_output + hidden_states
351357
352358 # --- token merging ---
353- r = self ._tome_info ["r" ].pop ( 0 )
359+ r = self ._tome_info ["r" ].popleft ( )
354360 if r > 0 :
355361 metric = self ._tome_info ["metric" ]
356362 merge , _ = _bipartite_soft_matching (
@@ -433,8 +439,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
433439 Any
434440 The output of the wrapped model's forward pass.
435441 """
436- # Make a copy of the list to avoid modifying the original
437- self ._tome_info ["r" ] = list (self .parsed_r )
442+ self ._tome_info ["r" ] = deque (self .parsed_r )
438443 self ._tome_info ["size" ] = None
439444 self ._tome_info ["source" ] = None
440445 self ._tome_info ["metric" ] = None
0 commit comments