-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget_model.py
More file actions
148 lines (134 loc) · 7.85 KB
/
get_model.py
File metadata and controls
148 lines (134 loc) · 7.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch
from models.fastvit.fastvit import fastvit_sa24, fastvit_sa36
from models.fastvit.fastvit_gradCAM import FastvitGradCAM
from models.fastvit.fastvit_utils import replace_depthwise_dcls_fastvit, transform_fastvit
from utils import device
from models.convnext.convnext import convnext_small, convnext_base, convnext_tiny
from models.convnext.convnext_utils import replace_depthwise_dcls_cnvnxt, transform_convnext
from models.convnext.convnext_gradCAM import ConvNextGradCAM
import copy
from models.resnet.resnet_utils import replace_depthwise_dcls_resnet
from models.resnet.resnet_gradCAM import resnetGradCAM
from models.metaformer.metaformer_baselines import convformer_s18, caformer_s18
from models.metaformer.metaformers_utils import replace_depthwise_dcls_metaformers, meta_transform
from models.metaformer.metaformers_gradCAM import MetaformerGradCAM
def _get_model_spec(model_name):
    """Return the build, DCLS-conversion and checkpoint settings for *model_name*.

    The returned dict has the keys:
        builder:   callable ``builder(pretrained=bool)`` returning the base model.
        replace:   family-specific ``replace_depthwise_dcls_*`` function.
        dcls:      keyword arguments forwarded to ``replace``.
        url:       URL of the DCLS-equipped pretrained checkpoint.
        state_key: key of the state dict inside the downloaded checkpoint.
        strict:    ``strict`` flag passed to ``load_state_dict``.
        cam:       Grad-CAM wrapper class for the model family.
        transform: preprocessing transform, or ``None`` to derive it from the
                   built model with timm's ``resolve_data_config``.

    The table is built on each call (rather than at import time) so that merely
    importing this module does not evaluate any of the model factories.

    Raises:
        ValueError: if *model_name* is not a supported model.
    """
    fastvit = dict(replace=replace_depthwise_dcls_fastvit, cam=FastvitGradCAM,
                   transform=transform_fastvit)
    convnext = dict(replace=replace_depthwise_dcls_cnvnxt, cam=ConvNextGradCAM,
                    transform=transform_convnext)
    metaformer = dict(replace=replace_depthwise_dcls_metaformers, cam=MetaformerGradCAM,
                      transform=meta_transform)
    resnet = dict(replace=replace_depthwise_dcls_resnet, cam=resnetGradCAM, transform=None)

    specs = {
        "fastvit_sa36": dict(
            fastvit, builder=fastvit_sa36,
            dcls=dict(dilated_kernel_size=17, kernel_count=34, version='v1'),
            url='https://zenodo.org/records/8370737/files/fastvit_sa36_v1_17lim_34el_seed0.pth.tar',
            state_key='state_dict_ema', strict=True),
        "fastvit_sa24": dict(
            fastvit, builder=fastvit_sa24,
            dcls=dict(dilated_kernel_size=17, kernel_count=34, version='v1'),
            url='https://zenodo.org/records/8370737/files/fastvit_sa24_v1_17lim_34el_seed0.pth.tar',
            state_key='state_dict_ema', strict=True),
        "convnext_tiny": dict(
            convnext, builder=convnext_tiny,
            dcls=dict(dilated_kernel_size=17, kernel_count=34, version='v0'),
            url="https://zenodo.org/record/7112021/files/convnext_dcls_tiny_1k_224_ema.pth",
            state_key='model', strict=True),
        "convnext_small": dict(
            convnext, builder=convnext_small,
            dcls=dict(dilated_kernel_size=17, kernel_count=40, version='v0'),
            url="https://zenodo.org/records/7112021/files/convnext_dcls_small_1k_224_ema.pth",
            state_key='model', strict=True),
        "convnext_base": dict(
            convnext, builder=convnext_base,
            dcls=dict(dilated_kernel_size=17, kernel_count=40, version='v0'),
            url="https://zenodo.org/record/7112021/files/convnext_dcls_base_1k_224_ema.pth",
            state_key='model', strict=True),
        "caformer_s18": dict(
            metaformer, builder=caformer_s18,
            dcls=dict(dilated_kernel_size=17, kernel_count=34, version='v1'),
            url="https://zenodo.org/records/8370737/files/caformer_s18_v1_17lim_34el_seed0.pth.tar?download=1",
            # NOTE(review): this checkpoint is loaded non-strictly; mismatched
            # keys are reported as a warning by _load_dcls_checkpoint.
            state_key='state_dict_ema', strict=False),
        "convformer_s18": dict(
            metaformer, builder=convformer_s18,
            dcls=dict(dilated_kernel_size=17, kernel_count=40, version='v1'),
            url="https://zenodo.org/records/8370737/files/convformer_s18_v1_17lim_40el_seed0.pth.tar",
            state_key='state_dict_ema', strict=True),
        "resnet50": dict(
            resnet,
            builder=lambda pretrained: timm.create_model('resnet50', pretrained=pretrained),
            dcls=dict(dilated_kernel_size=7, kernel_count=5, version='v0'),
            url="https://zenodo.org/records/8373830/files/resnet_dcls_kernel_5_model_best.pth.tar",
            state_key='state_dict_ema', strict=True),
    }
    try:
        return specs[model_name]
    except KeyError:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of {sorted(specs)}"
        ) from None


def _load_dcls_checkpoint(model, url, state_key, strict):
    """Download the checkpoint at *url* and load ``checkpoint[state_key]`` into *model*.

    When *strict* is False, any missing or unexpected keys are reported with a
    ``UserWarning`` instead of being silently ignored.
    """
    # NOTE(review): torch's check_hash only verifies files whose name ends in a
    # "-<sha256 prefix>" suffix; these Zenodo filenames carry none, so it is
    # effectively a no-op here.
    checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location=device, check_hash=True)
    incompatible = model.load_state_dict(checkpoint[state_key], strict=strict)
    if not strict and (incompatible.missing_keys or incompatible.unexpected_keys):
        import warnings
        warnings.warn(
            f"Checkpoint {url} did not match the model exactly: "
            f"missing keys={incompatible.missing_keys}, "
            f"unexpected keys={incompatible.unexpected_keys}",
            stacklevel=3,
        )


def get_model(model_name, dcls_equipped=True, pretrained=True):
    """Build a classifier, its Grad-CAM wrapper and its preprocessing transform.

    Args:
        model_name: one of "fastvit_sa36", "fastvit_sa24", "convnext_tiny",
            "convnext_small", "convnext_base", "caformer_s18", "convformer_s18",
            "resnet50".
        dcls_equipped: if True, the depthwise convolutions are replaced by DCLS
            layers using the model family's ``replace_depthwise_dcls_*`` helper.
        pretrained: if True, load pretrained weights. With ``dcls_equipped`` these
            are the DCLS checkpoints downloaded from Zenodo; otherwise the flag is
            forwarded to the model builder.

    Returns:
        ``(model, model_cam, transform)``: the model moved to ``device``, its
        Grad-CAM wrapper, and the input preprocessing transform.

    Raises:
        ValueError: if *model_name* is not supported.
    """
    spec = _get_model_spec(model_name)

    if dcls_equipped:
        # Build the bare architecture, swap in DCLS layers, and only then move
        # the whole model (including the newly created DCLS modules) to device.
        model = spec["builder"](pretrained=False)
        model = spec["replace"](model, **spec["dcls"])
        model = model.to(device)
        if pretrained:
            _load_dcls_checkpoint(model, spec["url"], spec["state_key"], spec["strict"])
    else:
        model = spec["builder"](pretrained=pretrained).to(device)

    model_cam = spec["cam"](model)

    transform = spec["transform"]
    if transform is None:
        # ResNet-50 uses timm's preprocessing resolved from the model's config.
        config = resolve_data_config({}, model=model)
        transform = create_transform(**config)
    return model, model_cam, transform