Skip to content

Commit 081af14

Browse files
Refactor label handling in Study_Models and update visualization process in todo.md
Fixed bugs for Study_Models.py
1 parent 1cdb442 commit 081af14

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

todo.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
For future version 4n2
22

33
```
4+
[2] Pred=0.438 → Credit Card Number
5+
[4] Pred=0.001 → Email Address
46
[6] Pred=0.000 → Private SSH key
57
[18] Pred=0.000 → Private RSA key
68
[30] Pred=0.003 → Credit card expiration

tools/Study_Models.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def visualize_tsne(model_, dataloader_, filename="Visualize_tSNE.png", use_penul
169169
else:
170170
feat = model_(X)
171171
all_features.append(feat.cpu().numpy())
172-
all_labels.append(y.cpu().numpy())
172+
all_labels.append(y.cpu().numpy().ravel())
173173
all_features = np.vstack(all_features)
174-
all_labels = np.hstack(all_labels)
174+
all_labels = np.concatenate(all_labels, axis=0)
175175
n_samples, n_features = all_features.shape
176176
n_components = min(2, n_samples, n_features) # adapt automatically
177177
tsne = TSNE(n_components=n_components, random_state=42, perplexity=max(1, min(30, n_samples - 1)))
@@ -302,8 +302,6 @@ def save_model_summary(model_, filename="Model_Summary.txt"):
302302
visualize_tsne_custom(model, embedder, test_texts, test_labels)
303303
print("Running visualize_feature_importance...")
304304
visualize_feature_importance(input_dim)
305-
print("Running plot_loss_landscape_3d...")
306-
plot_loss_landscape_3d(model, dataloader, criterion)
307305
print("Saving model state dict...")
308306
save_model_state_dict(model)
309307
print("Generating model visualization...")
@@ -312,5 +310,7 @@ def save_model_summary(model_, filename="Model_Summary.txt"):
312310
save_graph(model)
313311
print("Saving model summary...")
314312
save_model_summary(model)
315-
313+
print("Running plot_loss_landscape_3d...")
314+
model_cpu = model.to("cpu")
315+
plot_loss_landscape_3d(model_cpu, dataloader, criterion)
316316
print("All visualizations completed. Files saved in 'data/' directory.")

0 commit comments

Comments
 (0)