|
| 1 | +from typing import Any, List, Optional, Tuple, Union |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +from transformers.models.clip.modeling_clip import CLIPEncoderLayer |
| 6 | + |
| 7 | + |
def parse_r(num_layers: int, r: Union[List[int], Tuple[int, float], int]) -> List[int]:
    """Copied from ToMe. https://github.com/facebookresearch/ToMe.

    Process a constant r or r schedule into a list for use internally.

    r can take the following forms:
     - int: A constant number of tokens per layer.
     - Tuple[int, float]: A pair of r, inflection.
       Inflection describes where the reduction / layer should trend
       upward (+1), downward (-1), or stay constant (0). A value of (r, 0)
       is the same as providing a constant r. (r, -1) is what we describe in
       the paper as "decreasing schedule". Any value between -1 and +1 is
       accepted.
     - List[int]: A specific number of tokens per layer. For extreme granularity.

    Args:
        num_layers: Number of layers the schedule has to cover.
        r: Constant, (r, inflection) pair, or explicit per-layer list.

    Returns:
        A new list with one entry per layer. An explicit list shorter than
        ``num_layers`` is right-padded with zeros; a longer one is returned
        as an unmodified copy.
    """
    inflect = 0
    if isinstance(r, list):
        # Explicit schedule: pad missing trailing layers with "no reduction".
        if len(r) < num_layers:
            r = r + [0] * (num_layers - len(r))
        return list(r)
    elif isinstance(r, tuple):
        r, inflect = r

    if num_layers == 1:
        # A single layer has no trend to interpolate over, and the general
        # formula below would divide by zero. The schedule is symmetric
        # around r, so r itself is the value for the only layer.
        return [int(r)]

    # Linear ramp from min_val to max_val whose midpoint is r.
    min_val = int(r * (1.0 - inflect))
    max_val = 2 * r - min_val
    step = (max_val - min_val) / (num_layers - 1)

    return [int(min_val + step * i) for i in range(num_layers)]
| 35 | + |
| 36 | + |
def make_tome_class(transformer_class):
    """Build a subclass of ``transformer_class`` that resets ``_info`` per call.

    The returned class overrides only ``forward``; everything else is
    inherited from ``transformer_class``.
    """

    class VisionZipTransformer(transformer_class):
        """
        Modifications:
         - Initialize r, token size, and token sources.
        """

        def forward(self, *args, **kwargs) -> torch.Tensor:
            # Expand self.r into a per-layer schedule sized to the encoder.
            schedule = parse_r(len(self.vision_model.encoder.layers), self.r)

            # Refresh the shared state dict in place before delegating.
            state = self._info
            state['r'] = schedule
            state['size'] = None
            state['source'] = None

            return super().forward(*args, **kwargs)

    return VisionZipTransformer
| 54 | + |
| 55 | + |
def apply_info(model, dominant_num, contextual_num):
    """Patch ``model`` in place and share its ``_info`` dict with CLIP layers.

    Steps performed on ``model``:
     - its class is swapped for the subclass built by ``make_tome_class``;
     - ``model.r`` is set to a fixed 24-entry schedule (22 zeros, 1, 0);
     - ``model._info`` is created with the keys 'r', 'dominant', 'contextual';
     - every ``CLIPEncoderLayer`` found in ``model.modules()`` gets the same
       dict attached as ``self_attn.k_proj._info``.

    Args:
        model: Object to patch; must provide ``modules()``.
        dominant_num: Stored under ``_info['dominant']``.
        contextual_num: Stored under ``_info['contextual']``.
    """
    model.__class__ = make_tome_class(model.__class__)

    # NOTE(review): the 22 + 2 layout presumably targets a 24-layer encoder --
    # confirm before using with a different depth.
    model.r = [0] * 22 + [1, 0]

    shared_info = {
        'r': [model.r],
        'dominant': dominant_num,
        'contextual': contextual_num,
    }
    model._info = shared_info

    # Hand the very same dict object to each encoder layer's key projection.
    encoder_layers = (m for m in model.modules() if isinstance(m, CLIPEncoderLayer))
    for layer in encoder_layers:
        layer.self_attn.k_proj._info = model._info
0 commit comments