Skip to content

Commit 31ea6c8

Browse files
committed
add vision zip
1 parent 9f8a2e9 commit 31ea6c8

5 files changed

Lines changed: 604 additions & 5 deletions

File tree

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
base:
2+
seed: &seed 42
3+
model:
4+
type: Llava
5+
path: model path
6+
torch_dtype: auto
7+
eval:
8+
eval_pos: [transformed]
9+
type: vqa
10+
name: [mme]
11+
download: False
12+
path: MME dataset path
13+
bs: 1
14+
inference_per_block: False
15+
sparse:
16+
method: TokenReduction
17+
special:
18+
method: VisionZip
19+
dominant: 191
20+
contextual: 30
21+
save:
22+
save_trans: False
23+
save_fake: False
24+
save_path: /path/to/save/

llmc/compression/token_reduction/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .fastv import FastV
44
from .sparsevlm import SparseVLM
55
from .tome import ToMe
6+
from .visionzip import VisionZip

llmc/compression/token_reduction/sparsevlm.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
import functools
2-
import math
3-
import types
4-
from typing import Callable, Optional, Tuple
52

63
import einops as ein
74
import torch
8-
import torch.nn as nn
9-
import torch.nn.functional as F
105

116
from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY
127

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from typing import Any, List, Optional, Tuple, Union
2+
3+
import torch
4+
import torch.nn as nn
5+
from transformers.models.clip.modeling_clip import CLIPEncoderLayer
6+
7+
8+
def parse_r(num_layers: int, r: Union[List[int], Tuple[int, float], int]) -> List[int]:
9+
"""Copy from the TOME. https://github.com/facebookresearch/ToMe.
10+
11+
Process a constant r or r schedule into a list for use internally.
12+
13+
r can take the following forms:
14+
- int: A constant number of tokens per layer.
15+
- Tuple[int, float]: A pair of r, inflection.
16+
Inflection describes there the the reduction / layer should trend
17+
upward (+1), downward (-1), or stay constant (0). A value of (r, 0)
18+
is as providing a constant r. (r, -1) is what we describe in the paper
19+
as "decreasing schedule". Any value between -1 and +1 is accepted.
20+
- List[int]: A specific number of tokens per layer. For extreme granularity.
21+
"""
22+
inflect = 0
23+
if isinstance(r, list):
24+
if len(r) < num_layers:
25+
r = r + [0] * (num_layers - len(r))
26+
return list(r)
27+
elif isinstance(r, tuple):
28+
r, inflect = r
29+
30+
min_val = int(r * (1.0 - inflect))
31+
max_val = 2 * r - min_val
32+
step = (max_val - min_val) / (num_layers - 1)
33+
34+
return [int(min_val + step * i) for i in range(num_layers)]
35+
36+
37+
def make_tome_class(transformer_class):
38+
class VisionZipTransformer(transformer_class):
39+
"""
40+
Modifications:
41+
- Initialize r, token size, and token sources.
42+
"""
43+
44+
def forward(self, *args, **kwdargs) -> torch.Tensor:
45+
self._info['r'] = parse_r(len(self.vision_model.encoder.layers), self.r)
46+
# self._info["r"] = self.r
47+
48+
self._info['size'] = None
49+
self._info['source'] = None
50+
51+
return super().forward(*args, **kwdargs)
52+
53+
return VisionZipTransformer
54+
55+
56+
def apply_info(model, dominant_num, contextual_num):
57+
58+
VisionZipTransformer = make_tome_class(model.__class__)
59+
60+
model.__class__ = VisionZipTransformer
61+
model.r = [0 for i in range(22)] + [1] + [0]
62+
63+
model._info = {
64+
'r': [model.r],
65+
'dominant': dominant_num,
66+
'contextual': contextual_num,
67+
}
68+
for module in model.modules():
69+
if isinstance(module, CLIPEncoderLayer):
70+
module.self_attn.k_proj._info = model._info

0 commit comments

Comments
 (0)