-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathquery_model_layers.py
More file actions
60 lines (46 loc) · 1.76 KB
/
query_model_layers.py
File metadata and controls
60 lines (46 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
VITAL / zer0int: https://github.com/zer0int/VITAL-CLIP-class-concept-visualization
Use this script to figure out what to put in inner_neurons_fvis.py
...Especially for CNN layers -> layer: str = "layer4_2"
Example output of this script:
Model: ResNet (CNN), RN50x4
Total layers: 4
layer1 (0-3)
-> Channels: 0-319
layer2 (0-5)
-> Channels: 0-639
layer3 (0-9)
-> Channels: 0-1279
layer4 (0-5)
-> Channels: 0-2559
-> layer4_0 to layer4_5 are valid, for example.
For ViT models, it's simply the number of blocks, e.g.:
Model: ViT, zer0int/CLIP-Regression-ViT-L-14
Total blocks: 24 (0-23)
"""
import class_fvis.oaiclip as clip
from class_fvis.utils_clip_loader.clip_anything_to_openai import load_openai_clip_anything
import warnings # stop torch spam
# Suppress noisy warning categories (the original intent: "stop torch spam")
# so the layer summary printed below stays readable. Note these filters are
# installed after the imports above, so they only cover warnings raised from
# this point on (e.g. during model loading).
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Edit this to inspect a different model.
clip_model = "RN50x4"  # or e.g. zer0int/CLIP-Regression-ViT-L-14

# load_openai_clip_anything returns a 3-tuple; only the model is used here.
model, _, _ = load_openai_clip_anything(clip, clip_model)

if clip_model.startswith("RN"):
    # ResNet (CNN) visual backbone: report every direct child of
    # `model.visual` whose name starts with "layer" (layer1, layer2, ...).
    visual = model.visual
    stages = [(name, module) for name, module in visual.named_children()
              if name.startswith('layer')]
    print(f"\nModel: ResNet (CNN), {clip_model}")
    print(f"Total layers: {len(stages)}")
    for name, module in stages:
        # Valid block indices inside this stage, i.e. the "<name>_<idx>"
        # values usable as a layer id (e.g. layer4_0 .. layer4_5).
        print(f"{name} (0-{len(module) - 1})")
        # The channel count is taken from the first block's `conv3`;
        # presumably every block in a stage has the same width — TODO confirm
        # for non-OpenAI checkpoints.
        conv3 = getattr(module[0], 'conv3', None)
        if conv3 is None:
            # No `conv3` to read from: report it instead of doing arithmetic
            # on a placeholder string (which would raise TypeError).
            channels = "Unknown"
            print(f" -> Channels: {channels}")
        else:
            channels = conv3.out_channels
            print(f" -> Channels: 0-{channels - 1}")
else:
    # ViT visual backbone: each transformer residual block is one
    # selectable unit, indexed 0..num_blocks-1.
    resblocks = model.visual.transformer.resblocks
    num_blocks = len(resblocks)
    print(f"\nModel: ViT, {clip_model}")
    print(f"Total blocks: {num_blocks} (0-{num_blocks - 1})")