
feat: token merging for image classification #537

Merged: llcnt merged 28 commits into PrunaAI:main from rensortino:feat/token-merging (May 4, 2026)

Conversation

@rensortino (Contributor)

Description

This PR introduces the Token Merging (ToMe) algorithm for HuggingFace Vision Transformer models. Token Merging progressively merges similar tokens between the attention and MLP stages of each transformer block, significantly reducing the number of tokens and speeding up inference with minimal quality loss.

Using model google/vit-base-patch16-224, speedup is over 2x with r=8.
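For intuition on where the speedup comes from, here is a back-of-envelope token count (the figures for ViT-B/16 are standard: 14×14 patches plus one CLS token across 12 layers; the loop is purely illustrative, not the PR's code):

```python
# Merging r tokens per layer shrinks the sequence every later layer sees.
num_tokens, r, num_layers = 197, 16, 12   # ViT-B/16 at 224x224: 14*14 patches + 1 CLS
for _ in range(num_layers):
    num_tokens = max(num_tokens - r, 1)
print(num_tokens)  # 5 tokens remain after the last layer; with r=8 it would be 101
```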

Key Changes:

Token Merging Algorithm:

  • Implements the ToMe algorithm, adapted from the facebook/ToMe paper
  • Custom ViT module classes (ToMeViTLayer, ToMeViTSelfAttention) that extend HuggingFace transformers
  • Supports proportional attention weighting based on merged token sizes
  • Bipartite soft matching for intelligent token pair selection
  • Configurable token reduction schedule with per-layer control
  • Model wrapper for state management across forward passes

Testing Infrastructure:

  • Added ViT model fixtures for comprehensive testing
  • Token Merging test class with validation scenarios

Related Issue

Fixes #399

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

  • Token Merging algorithm tested with HuggingFace ViT models
  • Test fixtures added for google/vit-base-patch16-224 model family
  • Integration tests verify proper token reduction and attention output handling
  • Validated compatibility with existing Pruna pipeline

Implementation Details

Token Merging Core Features:

  1. Bipartite Soft Matching: Intelligently selects which token pairs to merge based on key similarity (see the sketch after this list)
  2. Proportional Attention: Adjusts attention weights by the log of merged token sizes
  3. Configurable Reduction Schedule:
    • Constant r across all layers
    • Per-layer list specification
    • Inflection-based schedules (increasing/decreasing/constant)
  4. Class Swapping Pattern: Dynamically replaces HF module classes at runtime to inject ToMe behavior
  5. Metric Storage: Uses key layer mean as similarity metric for matching
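
For concreteness, here is a condensed sketch of the bipartite soft matching step, adapted from the public facebook/ToMe reference code. It is illustrative rather than the PR's exact implementation: the real version additionally protects the class token and tracks per-token sizes for proportional attention.

```python
import torch

def bipartite_soft_matching_sketch(metric: torch.Tensor, r: int):
    # metric: (B, N, D), typically the attention keys averaged over heads.
    metric = metric / metric.norm(dim=-1, keepdim=True)   # normalize for cosine similarity
    a, b = metric[:, ::2], metric[:, 1::2]                # alternate tokens into two sets
    scores = a @ b.transpose(-1, -2)                      # (B, Na, Nb) pairwise similarity
    node_max, node_idx = scores.max(dim=-1)               # best partner in B for each A token
    edge_idx = node_max.argsort(dim=-1, descending=True)  # most mergeable A tokens first
    src_idx = edge_idx[:, :r]                             # the r A tokens to merge away
    dst_idx = node_idx.gather(1, src_idx)                 # their partners in B

    def merge(x: torch.Tensor) -> torch.Tensor:
        batch, _, dim = x.shape
        src, dst = x[:, ::2], x[:, 1::2].clone()
        merged = src.gather(1, src_idx[..., None].expand(batch, r, dim))
        # Fold each selected A token into its B partner by averaging.
        dst.scatter_reduce_(1, dst_idx[..., None].expand(batch, r, dim), merged, reduce="mean")
        unm_idx = edge_idx[:, r:]                         # A tokens that survive
        unm = src.gather(1, unm_idx[..., None].expand(batch, unm_idx.shape[1], dim))
        return torch.cat([unm, dst], dim=1)               # N - r tokens remain

    return merge

x = torch.randn(2, 197, 768)                              # a ViT-B/16 token sequence
merge = bipartite_soft_matching_sketch(metric=x, r=16)
print(merge(x).shape)                                     # torch.Size([2, 181, 768])
```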

Hyperparameters:

  • r (int, 0-128): Number of tokens to merge per layer (default: 16)
  • trace_source (bool): Track merge provenance for visualization
  • prop_attn (bool): Enable proportional attention weighting (default: True)
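
A minimal usage sketch, assuming the algorithm is enabled through Pruna's SmashConfig like other Pruna algorithms; the config keys below are hypothetical and should be checked against the documentation once this PR is released:

```python
from pruna import SmashConfig, smash
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

smash_config = SmashConfig()
# Hypothetical keys -- verify the exact algorithm name and hyperparameter
# prefix in the Pruna docs; only r, trace_source, and prop_attn come from this PR.
smash_config["token_merging"] = True
smash_config["token_merging_r"] = 16            # tokens merged per layer
smash_config["token_merging_prop_attn"] = True  # proportional attention on

smashed_model = smash(model=model, smash_config=smash_config)
```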

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

Design Decisions:

  1. Module-level class definitions: ToMeViTLayer and related classes are defined at module level (not inside methods) to ensure they are picklable for distributed training and model serialization.

  2. Eager attention enforcement: The ToMeViTSelfAttention class uses eager attention computation to inject the proportional attention bias between the QK matmul and softmax operations (sketched after this list).

  3. Shared mutable state: All ToMe modules share a single tome_info dict for efficient state management across layers.
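
Design decision 2 in code terms, a minimal sketch (tensor shapes assumed for illustration, not taken from the PR): eager mode exposes the pre-softmax scores, which is exactly where the log-size bias goes.

```python
import torch

def eager_attn_with_prop_bias(q: torch.Tensor, k: torch.Tensor,
                              v: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    # q, k, v: (B, H, N, D); size: (B, N) = patches merged into each token.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    # Proportional attention: softmax(QK^T / sqrt(d) + log s), so a token
    # standing for s patches attracts roughly s times the attention mass.
    scores = scores + size.log()[:, None, None, :]
    return scores.softmax(dim=-1) @ v
```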

Future Enhancements:

  • Extension to other transformer architectures (Flux, SAM, etc.)
  • Support for custom attention mechanisms

@sdiazlor requested a review from llcnt on February 17, 2026.
@llcnt (Collaborator) left a review comment:


Thanks a lot for the nice contribution :):)
Could you provide a working example of this new integration, please (e.g. a script or a notebook)? I tried to run it with a ViTForImageClassification but it fails (see comment below). I tried to run it with a pipeline from transformers; now the smashing works, but the inference fails: the base pipeline can accept str (image URLs) or raw images, but the smashed pipeline cannot. It would be nice to fix this so that the base model and the smashed one behave similarly. I tried to follow what you did in the new test by preprocessing the image before feeding it into the smashed pipeline, but I still get the error: TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions'.
Also, could you fix the cursor[bot] comments (some of them are quite relevant ;) )?
Thx in advance!

@github-actions bot commented Mar 6, 2026: This PR has been inactive for 10 days and is now marked as stale. (stale label added)
@rensortino (Contributor, Author):

Hi, thanks a lot for the feedback! I am working on the issues raised by Bugbot and will shortly provide a notebook with a basic example of how to test the algorithm.

@rensortino force-pushed the feat/token-merging branch from a6a79ec to be3a6ed on March 6, 2026.
@rensortino (Contributor, Author):

Here you can find a notebook to test the algorithm on HF models and pipelines.

@github-actions bot removed the stale label on Mar 7, 2026.
@github-actions bot commented Mar 17, 2026: This PR has been inactive for 10 days and is now marked as stale. (stale label added)
@llcnt (Collaborator) commented Mar 18, 2026:

bugbot run

@llcnt removed the stale label on Mar 18, 2026.
@cursor (Bot) left a review comment:

Cursor Bugbot has reviewed your changes and found 1 potential issue.
@llcnt (Collaborator) left a review comment:

Sorry for the delay, I was stuck with other tasks :(
Thanks for the updates, and for the notebook!
I was not able to run the whole notebook, though, because:

  • is_vit is not implemented. I guess you added it recently but did not re-run the notebook;
  • TypeError: ViTAttention.forward() got an unexpected keyword argument 'output_attentions' pops up when I run inference on the smashed model. I guess it's because you are using an old version of transformers. Can you print it (see the snippet below)? And make sure it is compatible with newer versions (4.56.0 is a good starting point, I would say) ;)
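
For reference, printing the installed version is a one-liner:

```python
import transformers

print(transformers.__version__)  # the review suggests targeting >= 4.56.0
```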

Thank you again for your contribution :)

@github-actions bot commented Mar 29, 2026: This PR has been inactive for 10 days and is now marked as stale. (stale label added)

@sdiazlor removed the stale label on Mar 30, 2026.
@codacy-production bot commented Apr 7, 2026:

Not up to standards ⛔

🔴 Issues: 3 high · 1 medium
Alerts: ⚠ 4 issues (gate: ≤ 0 issues of at least minor severity)
Results: 4 new issues

| Category | Results |
| --- | --- |
| Security | 3 high |
| Complexity | 1 medium |

🟢 Metrics: 45 complexity · 0 duplication

| Metric | Results |
| --- | --- |
| Complexity | 45 |
| Duplication | 0 |

@rensortino force-pushed the feat/token-merging branch from 82ab9d2 to dbcd5fd on April 7, 2026.
@rensortino (Contributor, Author):

Hi @llcnt, thank you for your patience and for the feedback!
I rebased my branch on main and fixed the issue raised by Cursor.
I also updated the branch with the notebook and added the missing functions. I tested with transformers==4.57.6 and it works.
I also fixed a small issue in the test: I was loading the image from a local file; now it loads it from a HF dataset (see the sketch below).
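
For reference, switching the test input from a local file to a Hub dataset can look like this; the dataset name is illustrative, not necessarily the one the test uses:

```python
from datasets import load_dataset

# "beans" is a small public image-classification dataset on the Hub; the
# PR's test may use a different one -- this only shows the pattern.
image = load_dataset("beans", split="test")[0]["image"]  # a PIL.Image
```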

@llcnt (Collaborator) left a review comment:

Thanks for adding the missing fns :)
Now my notebook runs correctly! However, when I run the base model on 100 input tensors of shape [1, 3, 224, 224], it takes 1s, while the smashed_model takes >1s on the same inputs: so no speedup, it is even slower :( Am I missing something? Do you still get the speedup on your side with transformers==4.57.6 and model "google/vit-large-patch16-224"?

@llcnt requested a review from oskarkuuse on April 10, 2026.
@rensortino (Contributor, Author):

I think I know what's causing the slowdown. To make my first version work, I used only eager attention, while, if I am not mistaken, HuggingFace models use SDPA by default. I will adapt it to use the same attention functions that HF models use for better compatibility, and run a proper benchmark on the original vs. smashed model. A sketch of the idea follows.
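
One way to avoid forcing eager attention, sketched under the assumption that the proportional-attention term can ride along as a floating-point attn_mask (SDPA adds a float mask to the pre-softmax scores); whether the PR takes exactly this route is not shown in the thread:

```python
import torch
import torch.nn.functional as F

def sdpa_with_prop_bias(q: torch.Tensor, k: torch.Tensor,
                        v: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    # size: (B, N) merged-token counts; a float attn_mask is added to the
    # pre-softmax scores, which is exactly where log(size) needs to go.
    bias = size.log()[:, None, None, :].to(q.dtype)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
```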

@llcnt (Collaborator) commented Apr 20, 2026:

> [quotes the previous comment]

Would be super nice if you have time to look at this :)

@rensortino (Contributor, Author):

Hi @llcnt, I added support for the HF attention mechanisms and rebased on main to integrate the latest changes.

I also ran a benchmark of performance (both inference speed and model accuracy) that you can run with the script on my test branch.

I am attaching the output of my benchmark for ViT-Base and ViT-Large, for r=8 and r=16.

r=8

Part 1 — Speed Benchmark (sdpa, fp32)

| Model | Batch | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| vit-base-patch16-224 | 8 | 9.99 | 8.92 | 801.1 | 897.1 | 1.120x (+12.0%) |
| vit-base-patch16-224 | 32 | 35.96 | 30.83 | 889.9 | 1037.9 | 1.166x (+16.6%) |
| vit-base-patch16-224 | 128 | 138.91 | 125.87 | 921.4 | 1016.9 | 1.104x (+10.4%) |
| vit-large-patch16-224 | 8 | 33.46 | 19.91 | 239.1 | 401.8 | 1.681x (+68.1%) |
| vit-large-patch16-224 | 32 | 117.71 | 66.83 | 271.9 | 478.8 | 1.761x (+76.1%) |
| vit-large-patch16-224 | 128 | 452.98 | 266.94 | 282.6 | 479.5 | 1.697x (+69.7%) |

Part 2 — Quality Check (100 samples)

| Model | top-1 == top-1 | top-1 in top-5 | top-1 mismatches | top-5 misses |
| --- | --- | --- | --- | --- |
| vit-base-patch16-224 | 100/100 PASS | 100/100 PASS | — | — |
| vit-large-patch16-224 | 96/100 WARN | 99/100 WARN | [8, 16, 46, 96] | [8] |

Part 3 — Attention Backend Comparison (batch_size=32)

| Model | Backend | dtype | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| vit-base | sdpa | fp32 | 36.37 | 31.32 | 879.9 | 1021.7 | 1.161x (+16.1%) |
| vit-base | eager | fp32 | 39.60 | 31.35 | 808.0 | 1020.7 | 1.263x (+26.3%) |
| vit-base | flash_attention_2 | fp16 | 10.35 | 11.06 | 3092.3 | 2892.1 | 0.935x (-6.5%) |
| vit-large | sdpa | fp32 | 118.73 | 67.42 | 269.5 | 474.6 | 1.761x (+76.1%) |
| vit-large | eager | fp32 | 127.46 | 67.22 | 251.1 | 476.1 | 1.896x (+89.6%) |
| vit-large | flash_attention_2 | fp16 | 34.18 | 23.02 | 936.3 | 1390.1 | 1.485x (+48.5%) |

r=16

Part 1 — Speed Benchmark (sdpa, fp32)

| Model | Batch | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| vit-base-patch16-224 | 8 | 10.07 | 6.72 | 794.7 | 1191.3 | 1.499x (+49.9%) |
| vit-base-patch16-224 | 32 | 36.23 | 21.57 | 883.2 | 1483.4 | 1.680x (+68.0%) |
| vit-base-patch16-224 | 128 | 140.06 | 84.18 | 913.9 | 1520.5 | 1.664x (+66.4%) |
| vit-large-patch16-224 | 8 | 33.79 | 11.93 | 236.7 | 670.7 | 2.833x (+183.3%) |
| vit-large-patch16-224 | 32 | 118.53 | 36.62 | 270.0 | 873.8 | 3.237x (+223.7%) |
| vit-large-patch16-224 | 128 | 455.24 | 140.48 | 281.2 | 911.2 | 3.241x (+224.1%) |

Part 2 — Quality Check (100 samples)

| Model | top-1 == top-1 | top-1 in top-5 | Notes |
| --- | --- | --- | --- |
| vit-base-patch16-224 | 99/100 WARN | 100/100 PASS | only sample 65 mismatches; still in top-5 |
| vit-large-patch16-224 | 0/100 FAIL | 4/100 FAIL | quality completely degraded — top-1 changes on every sample |

Part 3 — Attention Backend Comparison (batch_size=32)

| Model | Backend | dtype | Original (ms) | ToMe (ms) | Original (img/s) | ToMe (img/s) | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| vit-base | sdpa | fp32 | 36.41 | 21.72 | 878.9 | 1473.0 | 1.676x (+67.6%) |
| vit-base | eager | fp32 | 39.63 | 22.55 | 807.5 | 1418.9 | 1.757x (+75.7%) |
| vit-base | flash_attention_2 | fp16 | 14.51 | 9.14 | 2205.3 | 3499.7 | 1.587x (+58.7%) |
| vit-large | sdpa | fp32 | 118.76 | 38.93 | 269.5 | 822.1 | 3.051x (+205.1%) |
| vit-large | eager | fp32 | 127.57 | 36.64 | 250.8 | 873.3 | 3.482x (+248.2%) |
| vit-large | flash_attention_2 | fp16 | 34.19 | 13.27 | 935.9 | 2411.5 | 2.577x (+157.7%) |

@llcnt (Collaborator) commented Apr 27, 2026:

> [quotes the benchmark comment above in full]

Wow!! looks so nice :) I will have a look as soon as possible!

@llcnt (Collaborator) commented Apr 28, 2026:

bugbot run

@cursor (Bot) commented Apr 28, 2026:

Bugbot couldn't run: Bugbot is not enabled for your user on this team. Ask your team administrator to increase your team's hard limit for Bugbot seats or add you to the allowlist in the Cursor dashboard.

@llcnt (Collaborator) left a review comment:

Thanks for the adaptation to other attention backends! It now works and I do obtain some speedups with negligible quality loss :):)
I have left some minor comments that should not take long to address ;) After that, I think we will be very close to merging!

PS: this is not a comment for a code change, but I just want to keep the numbers here for future reference: when testing the vit-large model on a single H100 with r=8 and the eager backend, I did not get speed gains for batch sizes [1, 4]. Starting from batch_size=6, I do obtain some speed improvements.

@rensortino (Contributor, Author):

Hi @llcnt, I integrated the changes you highlighted and added support for head_mask in the eager attention mechanism, along with a warning.

Please let me know if you think anything else should be added before merging :)

@llcnt (Collaborator) left a review comment:

Thank you again for the nice work, and all the feedback :)

@llcnt merged commit a0072a6 into PrunaAI:main on May 4, 2026.
5 checks passed

Linked issue: [FEATURE] Implement Token Merging (#399)

3 participants