feat: token merging for image classification #537
Conversation
llcnt
left a comment
Thanks a lot for the nice contribution :):)
Could you provide a working example of this new integration, please (e.g. a script or a notebook)? I tried to run it with a ViTForImageClassification but it fails (see comment below). I also tried to run it with a pipeline from transformers: the smashing works, but inference fails. The base pipeline can accept str (URLs of images) or raw images, but the smashed pipeline cannot. It would be nice to fix this so that the base model and the smashed one behave similarly. I tried to follow what you did in the new test by preprocessing the image before feeding it into the smashed pipeline, but I still get the error: TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions'.
Also, could you fix the cursor[bot] comments (some of them are quite relevant ;) )?
Thx in advance!
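For reference, a minimal sketch of the behavior difference (the smashing step is illustrative pseudocode, not the exact API):

```python
# Base transformers image-classification pipeline: accepts a URL string
# or a raw image directly.
from transformers import pipeline

pipe = pipeline("image-classification", model="google/vit-base-patch16-224")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
print(pipe(url))  # works: returns a list of {label, score} dicts

# Hypothetical smashing step (illustrative only):
# smashed_pipe = smash(model=pipe, smash_config=...)
# smashed_pipe(url)  # currently raises instead of matching the base behavior
```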
This PR has been inactive for 10 days and is now marked as stale.
Hi, thanks a lot for the feedback! I am working on the issues raised by Bugbot and will shortly provide you with a notebook showing a basic example of how to test the algorithm.
force-pushed from a6a79ec to be3a6ed
Here you can find a notebook to test the algorithm on HF models and pipelines.
This PR has been inactive for 10 days and is now marked as stale.
bugbot run
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Comment @cursor review or bugbot run to trigger another review on this PR
llcnt
left a comment
Sorry for the delay, I was stuck with other tasks :(
Thanks for the updates, and for the notebook!
I was not able to run the whole notebook though, because:
- `is_vit` is not implemented. I guess you added it lately, but did not re-run the notebook;
- TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions' pops up when I run inference on the smashed model. I guess this is because you are using an old version of transformers. Can you print it? And make sure it is compatible with newer versions (4.56.0 is a good starting point, I would say) ;)
Thank you again for your contribution :)
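(For reference, a quick way to print the installed version:)

```python
import transformers

print(transformers.__version__)  # should be compatible with >= 4.56.0
```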
This PR has been inactive for 10 days and is now marked as stale.
Not up to standards ⛔

🔴 Issues

| Category | Results |
|---|---|
| Security | 3 high |
| Complexity | 1 medium |

🟢 Metrics

| Metric | Results |
|---|---|
| Complexity | 45 |
| Duplication | 0 |

TIP: This summary will be updated as you push new changes.
force-pushed from 82ab9d2 to dbcd5fd
Hi @llcnt, thank you for your patience and for the feedback!
llcnt
left a comment
Thanks for adding the missing fns :)
Now my notebook is running correctly! However, when I run the base model on 100 input tensors of shape [1, 3, 224, 224], it takes 1s, while the smashed model takes >1s on the same inputs: so no speedup, it is even slower :( Am I missing something? Do you still get the speedup on your side with transformers version ==4.57.6 and model "google/vit-large-patch16-224"?
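For reference, a minimal sketch of this kind of timing comparison (illustrative, assuming a CUDA device; `model` is either the base or the smashed model):

```python
import time

import torch

def time_model(model, n_iters=100, shape=(1, 3, 224, 224), device="cuda"):
    model = model.to(device).eval()
    x = torch.randn(*shape, device=device)
    with torch.no_grad():
        for _ in range(10):          # warmup before timing
            model(x)
        torch.cuda.synchronize()     # make sure queued kernels finish
        start = time.perf_counter()
        for _ in range(n_iters):
            model(x)
        torch.cuda.synchronize()
    return time.perf_counter() - start
```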
I think I know what's causing the slowdown. To make my first version work, I used only eager attention, while, if I am not mistaken, HuggingFace models use SDPA by default. I will adapt it to use the same attention functions that HF models use for better compatibility, and run a proper benchmark on the original vs. smashed model.
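For context, a minimal sketch (my illustration, not the PR's code) of why the eager path is the easy first target: the proportional-attention bias has to be added between the QK matmul and the softmax, so the pre-softmax scores must be accessible:

```python
import math

import torch

def eager_attention_with_size_bias(q, k, v, size):
    # q, k, v: [batch, heads, tokens, head_dim]; size: [batch, tokens, 1],
    # where size[b, j] counts how many original tokens merged token j represents.
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    scores = scores + size.log()[:, None, None, :, 0]  # bias over key tokens
    return scores.softmax(dim=-1) @ v
```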
Would be super nice if you have time to look at this :)
force-pushed from 4b562ab to cfcd717
Hi @llcnt, I added support for the HF attention mechanisms and rebased. I also ran a benchmark on performance (both inference speed and model accuracy) that you can reproduce with the script on my test branch. I am attaching the output of my benchmark for ViT-Base and ViT-Large, for r=8 and r=16.
r=8
Part 1 — Speed Benchmark (sdpa, fp32)
| Model | Batch | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
|---|---|---|---|---|---|---|
| vit-base-patch16-224 | 8 | 9.99 | 8.92 | 801.1 | 897.1 | 1.120x (+12.0%) |
| vit-base-patch16-224 | 32 | 35.96 | 30.83 | 889.9 | 1037.9 | 1.166x (+16.6%) |
| vit-base-patch16-224 | 128 | 138.91 | 125.87 | 921.4 | 1016.9 | 1.104x (+10.4%) |
| vit-large-patch16-224 | 8 | 33.46 | 19.91 | 239.1 | 401.8 | 1.681x (+68.1%) |
| vit-large-patch16-224 | 32 | 117.71 | 66.83 | 271.9 | 478.8 | 1.761x (+76.1%) |
| vit-large-patch16-224 | 128 | 452.98 | 266.94 | 282.6 | 479.5 | 1.697x (+69.7%) |
Part 2 — Quality Check (100 samples)
| Model | top-1 == top-1 | top-1 in top-5 | top-1 mismatches | top-5 misses |
|---|---|---|---|---|
| vit-base-patch16-224 | 100/100 PASS | 100/100 PASS | — | — |
| vit-large-patch16-224 | 96/100 WARN | 99/100 WARN | [8, 16, 46, 96] | [8] |
Part 3 — Attention Backend Comparison (batch_size=32)
| Model | Backend | dtype | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
|---|---|---|---|---|---|---|---|
| vit-base | sdpa | fp32 | 36.37 | 31.32 | 879.9 | 1021.7 | 1.161x (+16.1%) |
| vit-base | eager | fp32 | 39.60 | 31.35 | 808.0 | 1020.7 | 1.263x (+26.3%) |
| vit-base | flash_attention_2 | fp16 | 10.35 | 11.06 | 3092.3 | 2892.1 | 0.935x (-6.5%) |
| vit-large | sdpa | fp32 | 118.73 | 67.42 | 269.5 | 474.6 | 1.761x (+76.1%) |
| vit-large | eager | fp32 | 127.46 | 67.22 | 251.1 | 476.1 | 1.896x (+89.6%) |
| vit-large | flash_attention_2 | fp16 | 34.18 | 23.02 | 936.3 | 1390.1 | 1.485x (+48.5%) |
r=16
Part 1 — Speed Benchmark (sdpa, fp32)
| Model | Batch | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
|---|---|---|---|---|---|---|
| vit-base-patch16-224 | 8 | 10.07 | 6.72 | 794.7 | 1191.3 | 1.499x (+49.9%) |
| vit-base-patch16-224 | 32 | 36.23 | 21.57 | 883.2 | 1483.4 | 1.680x (+68.0%) |
| vit-base-patch16-224 | 128 | 140.06 | 84.18 | 913.9 | 1520.5 | 1.664x (+66.4%) |
| vit-large-patch16-224 | 8 | 33.79 | 11.93 | 236.7 | 670.7 | 2.833x (+183.3%) |
| vit-large-patch16-224 | 32 | 118.53 | 36.62 | 270.0 | 873.8 | 3.237x (+223.7%) |
| vit-large-patch16-224 | 128 | 455.24 | 140.48 | 281.2 | 911.2 | 3.241x (+224.1%) |
Part 2 — Quality Check (100 samples)
| Model | top-1 == top-1 | top-1 in top-5 | Notes |
|---|---|---|---|
| vit-base-patch16-224 | 99/100 WARN | 100/100 PASS | only sample 65 mismatches; still in top-5 |
| vit-large-patch16-224 | 0/100 FAIL | 4/100 FAIL | quality completely degraded — top-1 changes on every sample |
Part 3 — Attention Backend Comparison (batch_size=32)
| Model | Backend | dtype | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
|---|---|---|---|---|---|---|---|
| vit-base | sdpa | fp32 | 36.41 | 21.72 | 878.9 | 1473.0 | 1.676x (+67.6%) |
| vit-base | eager | fp32 | 39.63 | 22.55 | 807.5 | 1418.9 | 1.757x (+75.7%) |
| vit-base | flash_attention_2 | fp16 | 14.51 | 9.14 | 2205.3 | 3499.7 | 1.587x (+58.7%) |
| vit-large | sdpa | fp32 | 118.76 | 38.93 | 269.5 | 822.1 | 3.051x (+205.1%) |
| vit-large | eager | fp32 | 127.57 | 36.64 | 250.8 | 873.3 | 3.482x (+248.2%) |
| vit-large | flash_attention_2 | fp16 | 34.19 | 13.27 | 935.9 | 2411.5 | 2.577x (+157.7%) |
Wow!! looks so nice :) I will have a look as soon as possible!
bugbot run
Bugbot couldn't run. Bugbot is not enabled for your user on this team. Ask your team administrator to increase your team's hard limit for Bugbot seats or add you to the allowlist in the Cursor dashboard.
llcnt
left a comment
Thanks for the adaptation to other attention backends! It now works and I do obtain some speedups with negligible quality loss :):)
I have left some minor comments that should not take long to address ;) After that, I think we will be very close to merging!
PS: this is not a comment asking for a code change, but I just want to keep the numbers here for future reference: when testing the vit-large model on a single H100 with r=8 and the eager backend, I did not get speed gains for batch sizes [1, 4]. Starting from batch_size=6, I do obtain some speed improvements.
Co-authored-by: llcnt <73026329+llcnt@users.noreply.github.com>
Hi @llcnt, I integrated the changes you highlighted and added the support for […]. Please let me know if you think anything else should be added before merging :)
llcnt
left a comment
Thank you again for the nice work, and all the feedback :)

Description
This PR introduces the Token Merging (ToMe) algorithm for HuggingFace Vision Transformer models. Token Merging progressively merges similar tokens between the attention and MLP stages of each transformer block, significantly reducing the number of tokens and speeding up inference with minimal quality loss.
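To give reviewers a feel for the core step, here is a simplified sketch of ToMe's bipartite soft matching (an illustration only, not this PR's exact code: it ignores protected tokens such as CLS and the size-weighted averaging of the full method):

```python
import torch

def bipartite_soft_matching_merge(x: torch.Tensor, r: int) -> torch.Tensor:
    """Merge the r most similar token pairs in x of shape [B, N, C] (N even)."""
    with torch.no_grad():
        metric = x / x.norm(dim=-1, keepdim=True)     # cosine-normalized features
        a, b = metric[:, ::2, :], metric[:, 1::2, :]  # alternating bipartite split
        scores = a @ b.transpose(-1, -2)              # [B, N/2, N/2] similarities
        node_max, node_idx = scores.max(dim=-1)       # best partner in B per A token
        edge_order = node_max.argsort(dim=-1, descending=True)
        merged_idx = edge_order[:, :r]                # A tokens merged away
        kept_idx = edge_order[:, r:]                  # A tokens kept as-is
        dst_idx = node_idx.gather(1, merged_idx)      # their destinations in B

    src, dst = x[:, ::2, :], x[:, 1::2, :]
    C = x.shape[-1]
    kept = src.gather(1, kept_idx.unsqueeze(-1).expand(-1, -1, C))
    moved = src.gather(1, merged_idx.unsqueeze(-1).expand(-1, -1, C))
    # Fold each merged token into its destination by averaging.
    dst = dst.scatter_reduce(1, dst_idx.unsqueeze(-1).expand(-1, -1, C),
                             moved, reduce="mean", include_self=True)
    return torch.cat([kept, dst], dim=1)              # [B, N - r, C]
```

In the full algorithm this runs between the attention and MLP stages of every block, using the attention keys as the similarity metric and a size-weighted average for the merge.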
Using model `google/vit-base-patch16-224`, speedup is over 2x with `r=8`.

Key Changes:
Token Merging Algorithm:
- New classes (`ToMeViTLayer`, `ToMeViTSelfAttention`) that extend HuggingFace transformers

Testing Infrastructure:
Related Issue
Fixes #399
Type of Change
How Has This Been Tested?
- `google/vit-base-patch16-224` model family

Implementation Details
Token Merging Core Features:
- the same `r` across all layers

Hyperparameters:
- `r` (int, 0-128): number of tokens to merge per layer (default: 16)
- `trace_source` (bool): track merge provenance for visualization
- `prop_attn` (bool): enable proportional attention weighting (default: True)

Checklist
Additional Notes
Design Decisions:
- Module-level class definitions: `ToMeViTLayer` and related classes are defined at module level (not inside methods) to ensure they are picklable for distributed training and model serialization.
- Eager attention enforcement: the `ToMeViTSelfAttention` class uses eager attention computation to inject the proportional attention bias between the QK matmul and softmax operations.
- Shared mutable state: all ToMe modules share a single `tome_info` dict for efficient state management across layers.

Future Enhancements:
References: