-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmodules.py
More file actions
73 lines (62 loc) · 2.93 KB
/
modules.py
File metadata and controls
73 lines (62 loc) · 2.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
import torch.nn.functional as F
import dino.vision_transformer as vits
class DinoFeaturizer(nn.Module):
    """Frozen, pretrained DINO ViT backbone used as a dense feature extractor.

    Builds a DINO vision transformer, freezes all of its parameters, loads
    the reference pretrained weights from fbaipublicfiles, and exposes
    per-patch features as a (B, C, H/patch, W/patch) map via ``forward``.

    The configuration was previously hard-coded; it is now parameterized
    with defaults that reproduce the original behavior exactly.
    """

    # Reference DINO checkpoint paths, keyed by (arch, patch_size).
    _CHECKPOINTS = {
        ("vit_small", 16): "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
        ("vit_small", 8): "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth",
        ("vit_base", 16): "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth",
        ("vit_base", 8): "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth",
    }

    def __init__(self, arch='vit_small', patch_size=16, feat_type='feat',
                 dropout_p=0.1):
        """Construct the featurizer and load pretrained DINO weights.

        Args:
            arch: DINO architecture name, ``'vit_small'`` or ``'vit_base'``.
            patch_size: ViT patch size, 8 or 16.
            feat_type: ``'feat'`` (token features) or ``'KK'`` (key features).
            dropout_p: probability for the optional 2D dropout.

        Raises:
            ValueError: if (arch, patch_size) has no known checkpoint.
        """
        super().__init__()
        self.dim = 70  # dim (kept from original; not used within this class)
        self.patch_size = patch_size
        self.feat_type = feat_type
        self.model = vits.__dict__[arch](
            patch_size=patch_size,
            num_classes=0)
        # Backbone is frozen: features only, no gradient flow.
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()
        # Fix: only move to GPU when one exists, so CPU-only machines work.
        if torch.cuda.is_available():
            self.model.cuda()
        self.dropout = torch.nn.Dropout2d(p=dropout_p)
        self.whetherdropout = False
        try:
            url = self._CHECKPOINTS[(arch, patch_size)]
        except KeyError:
            raise ValueError("Unknown arch and patch size")
        print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
        state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
        self.model.load_state_dict(state_dict, strict=True)
        # Embedding width of the chosen architecture.
        self.n_feats = 384 if arch == "vit_small" else 768

    def forward(self, img, n=1, return_class_feat=False):
        """Extract dense features from a batch of images.

        Args:
            img: tensor of shape (B, C, H, W); H and W must be divisible
                by ``self.patch_size``.
            n: which intermediate layer (from the end) to read features from.
            return_class_feat: if True, return only the class-token feature
                reshaped to (B, D, 1, 1).

        Returns:
            Tuple ``(image_feat, code)`` where both are the (B, D, h, w)
            patch-feature map (``code`` aliases ``image_feat``); dropout is
            applied to the first element only when ``self.whetherdropout``.

        Raises:
            ValueError: if ``self.feat_type`` is not ``'feat'`` or ``'KK'``.
        """
        self.model.eval()
        with torch.no_grad():
            assert (img.shape[2] % self.patch_size == 0)
            assert (img.shape[3] % self.patch_size == 0)
            # get selected layer activations
            feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
            feat, attn, qkv = feat[0], attn[0], qkv[0]
            feat_h = img.shape[2] // self.patch_size
            feat_w = img.shape[3] // self.patch_size
            if self.feat_type == "feat":
                # Drop the class token, fold the patch tokens back into a grid.
                image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
            elif self.feat_type == "KK":
                # qkv layout assumed (3, B, heads, tokens, head_dim) — TODO confirm
                # against dino.vision_transformer.get_intermediate_feat.
                # Fix: read the head count from the tensor instead of the
                # hard-coded 6, which was only valid for vit_small.
                num_heads = qkv.shape[2]
                image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], num_heads, feat_h, feat_w, -1)
                B, H, I, J, D = image_k.shape
                image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J)
            else:
                raise ValueError("Unknown feat type:{}".format(self.feat_type))
            if return_class_feat:
                return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2)
        code = image_feat
        if self.whetherdropout:
            return self.dropout(image_feat), code
        else:
            return image_feat, code