-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextract_features_e.py
More file actions
126 lines (104 loc) · 3.6 KB
/
extract_features_e.py
File metadata and controls
126 lines (104 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
os.environ["MAMBA_FORCE_FALLBACK"] = "1"
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import AutoModel
from transformers.modeling_utils import PreTrainedModel
# Compatibility shim: attach an ``all_tied_weights_keys`` property to
# PreTrainedModel only when the installed transformers does not define one.
# NOTE(review): presumably some loading code (or the remote MambaVision model
# code) reads this attribute on versions that lack it — confirm the caller.
if not hasattr(PreTrainedModel, "all_tied_weights_keys"):

    @property
    def all_tied_weights_keys(self):
        """Return ``self._tied_weights_keys`` normalized to a dict.

        A list/tuple/set of keys becomes ``{key: None, ...}``; a dict is
        returned as-is; anything else (including a missing attribute)
        yields an empty dict.
        """
        tied = getattr(self, "_tied_weights_keys", None)
        if isinstance(tied, (list, tuple, set)):
            return dict.fromkeys(tied)
        return tied if isinstance(tied, dict) else {}

    PreTrainedModel.all_tied_weights_keys = all_tied_weights_keys
# -------------------------------------------------------------
# Runtime configuration
# -------------------------------------------------------------
# E-dataset root (full set, 20 images per subject), read with ImageFolder.
data_dir = "/home/teaching/dl_mamba/data_e_clean"
# State dict to restore; presumably written by the training script.
model_path = "best_mamba_model.pth"
batch_size = 16
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# -------------------------------------------------------------
# Preprocessing: resize to 224x224, convert to tensor, normalize per channel
# -------------------------------------------------------------
# Per-channel mean/std (the usual ImageNet statistics) for Normalize.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]
)
# -------------------------------------------------------------
# Dataset and loader
# -------------------------------------------------------------
# ImageFolder derives integer labels from the class sub-folder names.
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
# No shuffling: batches follow dataset index order, so extracted rows line up
# with dataset.samples.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# =========================
# MODEL (same as training)
# =========================
class MambaClassifier(nn.Module):
    """MambaVision-T backbone plus a linear head, mirroring the training model.

    ``forward`` returns pooled backbone features and deliberately does NOT
    apply ``self.classifier``, because this script only extracts features.
    The head is still constructed so that the module's parameter names match
    the checkpoint restored with ``load_state_dict`` (strict by default).

    Args:
        num_classes: Output size of the linear head (not used by ``forward``).
    """

    def __init__(self, num_classes):
        super().__init__()
        # The MambaVision architecture code lives in the model repo, hence
        # trust_remote_code=True.
        self.backbone = AutoModel.from_pretrained(
            "nvidia/MambaVision-T-1K",
            trust_remote_code=True,
            ignore_mismatched_sizes=True,
        )
        # Disable embedding tying if the backbone config exposes the flag.
        if hasattr(self.backbone.config, "tie_word_embeddings"):
            self.backbone.config.tie_word_embeddings = False
        # LazyLinear leaves in_features undetermined until the first input or
        # until weights are loaded into it via load_state_dict.
        self.classifier = nn.LazyLinear(num_classes)

    @staticmethod
    def _select_hidden(outputs):
        """Return the primary tensor from a backbone output of unknown structure.

        Lookup order: a non-None ``last_hidden_state`` attribute; for a dict,
        its ``"last_hidden_state"`` entry or else its first value; for a
        non-empty tuple/list, its first element; otherwise ``outputs`` itself.
        One further level of tuple/list nesting is then unwrapped.

        Raises:
            TypeError: if no tensor can be found (e.g. empty tuple/dict, None).
        """
        hidden = getattr(outputs, "last_hidden_state", None)
        if hidden is None:
            if isinstance(outputs, dict):
                hidden = outputs.get("last_hidden_state")
                if hidden is None and len(outputs) > 0:
                    hidden = next(iter(outputs.values()))
            elif isinstance(outputs, (tuple, list)) and len(outputs) > 0:
                hidden = outputs[0]
            else:
                hidden = outputs
        if isinstance(hidden, (tuple, list)) and len(hidden) > 0:
            hidden = hidden[0]
        if not isinstance(hidden, torch.Tensor):
            # Fail loudly instead of an opaque "'X' has no attribute 'dim'".
            raise TypeError(
                "Could not find a tensor in backbone output of type "
                f"{type(outputs).__name__}"
            )
        return hidden

    @staticmethod
    def _pool(hidden):
        """Reduce a batched tensor to one feature vector per sample.

        4-D input is averaged over its last two axes and 3-D input over axis 1
        (assumes (B, C, H, W) and (B, tokens, C) layouts respectively — TODO
        confirm against the backbone). 2-D input is returned unchanged, and
        higher-rank input is flattened after the batch axis.

        Raises:
            ValueError: if ``hidden`` has fewer than 2 dimensions.
        """
        ndim = hidden.dim()
        if ndim < 2:
            raise ValueError(
                f"Expected a batched backbone output with >= 2 dims, got {ndim}-D"
            )
        if ndim == 4:
            return hidden.mean(dim=(2, 3))
        if ndim == 3:
            return hidden.mean(dim=1)
        if ndim == 2:
            return hidden
        return hidden.flatten(start_dim=1)

    def forward(self, x):
        """Return pooled backbone features (one row per sample) for batch ``x``."""
        return self._pool(self._select_hidden(self.backbone(x)))
# -------------------------------------------------------------
# Model restore
# -------------------------------------------------------------
num_classes = len(dataset.classes)
model = MambaClassifier(num_classes)
checkpoint_state = torch.load(model_path, map_location=device)
# Strict key matching (the load_state_dict default).
model.load_state_dict(checkpoint_state)
# Drop the extra reference so the loaded tensors can be freed.
del checkpoint_state
model.to(device)
# Evaluation behaviour (e.g. dropout disabled) for extraction.
model.eval()
print("Model loaded!")
# =========================
# EXTRACT
# =========================
# Run every image through the model without gradients and collect the
# features with their ImageFolder labels, in dataset order (the loader does
# not shuffle).
feature_batches = []
label_batches = []
with torch.no_grad():
    for images, labels in loader:
        feature_batches.append(model(images.to(device)).cpu())
        label_batches.append(labels)
all_features = torch.cat(feature_batches)
all_labels = torch.cat(label_batches)
print("E features shape:", all_features.shape)
# "labels" are integer indices into dataset.classes. The mapping is saved as
# well so the file is self-describing. Consumers that read only
# "features"/"labels" are unaffected by the extra key.
torch.save(
    {
        "features": all_features,
        "labels": all_labels,
        "classes": list(dataset.classes),
    },
    "features_e.pt",
)
print("Saved features_e.pt")