Skip to content

Commit e45ebed

Browse files
committed
dart,fatev,vispruner for llava1.6,update sparsevlm
1 parent c038265 commit e45ebed

9 files changed

Lines changed: 395 additions & 123 deletions

File tree

configs/sparsification/methods/SparseVLM/sparsevlm.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ sparse:
1717
special:
1818
method: SparseVLM
1919
pruning_loc: [2, 6, 15]
20-
retained_tokens: 192
21-
prune_flag: True
20+
reduction_ratio: 0.6667
2221
merge_flag: True
2322
save:
2423
save_trans: False

llmc/compression/token_reduction/dart.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import functools
2-
import math
2+
from types import MethodType
33

44
import torch
55

@@ -24,26 +24,20 @@ def add_sparse_config(self):
2424
def register_reduction_modules(self):
2525

2626
@prefill_wrapper
27-
def vtoken_length_hook(module, input_args, pruning_paras):
28-
29-
input_ids = input_args[0]
27+
def vtoken_length_hook(module, args, pruning_paras):
28+
input_ids = args[0]
3029
token_indices = torch.where(
3130
input_ids[0] == pruning_paras['vision_token_index']
3231
)[0]
3332
pruning_paras['vision_token_length'] = token_indices.shape[0]
3433

35-
return input_args
36-
3734
@prefill_wrapper
3835
def get_any_states_hook(module, args, kwargs, layer_outs, pruning_paras, layer_idx):
39-
4036
past_key_value = kwargs['past_key_value']
4137
if past_key_value is None:
4238
raise ValueError('DART needs past_key_value but got None.')
4339
pruning_paras['any_states'] = past_key_value.key_cache[layer_idx]
4440

45-
return layer_outs
46-
4741
@prefill_wrapper
4842
def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
4943

@@ -95,9 +89,17 @@ def pruning_hook(module, args, kwargs, pruning_paras, normlayer):
9589
return (hidden_states,), kwargs
9690

9791
if self.special_config['vision_token_length'] is None:
98-
self.model.embed_tokens.register_forward_pre_hook(
99-
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
100-
)
92+
if self.model.__class__.__name__ == 'Llava':
93+
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
94+
self.vtoken_length_for_llava_hook(
95+
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
96+
self.pruning_paras
97+
), self.model.vlm_model
98+
)
99+
else:
100+
self.model.embed_tokens.register_forward_pre_hook(
101+
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
102+
)
101103

102104
self.blocks[self.pruning_loc - 1].register_forward_hook(
103105
functools.partial(

llmc/compression/token_reduction/fastv.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
from types import MethodType
23

34
import torch
45

@@ -104,9 +105,17 @@ def fastv_pruning_hook(module, args, kwargs, pruning_paras):
104105
return (hidden_states,), kwargs
105106

106107
if self.special_config['vision_token_length'] is None:
107-
self.model.embed_tokens.register_forward_pre_hook(
108-
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
109-
)
108+
if self.model.__class__.__name__ == 'Llava':
109+
self.model.vlm_model.prepare_inputs_labels_for_multimodal = MethodType(
110+
self.vtoken_length_for_llava_hook(
111+
self.model.vlm_model.prepare_inputs_labels_for_multimodal,
112+
self.pruning_paras
113+
), self.model.vlm_model
114+
)
115+
else:
116+
self.model.embed_tokens.register_forward_pre_hook(
117+
functools.partial(vtoken_length_hook, pruning_paras=self.pruning_paras)
118+
)
110119

111120
self.blocks[self.pruning_loc - 1].register_forward_pre_hook(
112121
functools.partial(update_output_attentions_hook, pruning_paras=self.pruning_paras),

0 commit comments

Comments
 (0)