Skip to content

Commit 6a86e14

Browse files
Enhance visualization functionality and update data loading paths in Study_Models and Test_Model scripts
Enhance visualization process and model saving in Study_Models Enhance file loading in Test_Model to include error handling and update validation pattern
1 parent 04eea5a commit 6a86e14

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,4 @@ __marimo__/
208208
/models/
209209
/cache/
210210
/data/
211+
/*Data_Visualization/

.idea/VulnScan.iml

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tools/Study_Models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,11 @@ def save_model_summary(model_, filename="Model_Summary.txt"):
286286
criterion = nn.BCEWithLogitsLoss()
287287

288288
# Run all visualizations
289+
print("Running visualize_weight_distribution...")
289290
visualize_weight_distribution(model)
291+
print("Running visualize_activations...")
290292
visualize_activations(model, sample_input)
293+
print("Preparing texts and labels for t-SNE custom visualization...")
291294
texts = [
292295
# Non-sensitive (0)
293296
"I need to buy milk and bread from the grocery store.",
@@ -300,14 +303,23 @@ def save_model_summary(model_, filename="Model_Summary.txt"):
300303
"My social security number is 123-45-6789."
301304
]
302305
labels = [0, 0, 0, 1, 1, 1]
306+
print("Loading SentenceTransformer embedder...")
303307
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
308+
print("Running visualize_tsne...")
304309
visualize_tsne(model, dataloader)
310+
print("Running visualize_tsne_custom...")
305311
visualize_tsne_custom(model, embedder, texts, labels)
312+
print("Running visualize_feature_importance...")
306313
visualize_feature_importance(input_dim)
314+
print("Running plot_loss_landscape_3d...")
307315
plot_loss_landscape_3d(model, dataloader, criterion)
316+
print("Saving model state dict...")
308317
save_model_state_dict(model)
318+
print("Generating model visualization...")
309319
generate_model_visualization(model, input_dim)
320+
print("Saving graph...")
310321
save_graph(model)
322+
print("Saving model summary...")
311323
save_model_summary(model)
312324

313325
print("All visualizations completed. Files saved in 'data/' directory.")

tools/Test_Model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import sys
2+
13
import torch
24
from sentence_transformers import SentenceTransformer
35
from vulnscan import SimpleNN
@@ -15,6 +17,9 @@
1517
def load_embeddings(folder_path, pattern):
1618
"""Load all .pt files matching pattern and concatenate embeddings and labels"""
1719
files = sorted(glob.glob(os.path.join(folder_path, pattern)))
20+
print("Found files:", files)
21+
if not files:
22+
sys.exit(f"No files found in {folder_path} matching {pattern}")
1823
all_embeddings = []
1924
all_labels = []
2025
for f in files:
@@ -27,18 +32,18 @@ def load_embeddings(folder_path, pattern):
2732

2833

2934
# Example paths
30-
cache_dir = f"cache/{NAME}/round_{ROUND}/embeddings"
35+
cache_dir = f"../cache/{NAME}/round_{ROUND}/embeddings"
3136

3237
# Load all train/test/val embeddings
3338
train_embeddings, train_labels = load_embeddings(cache_dir, "train_*.pt")
3439
test_embeddings, test_labels = load_embeddings(cache_dir, "test_*.pt")
35-
val_embeddings, val_labels = load_embeddings(cache_dir, "val_*.pt")
40+
val_embeddings, val_labels = load_embeddings(cache_dir, "validation_*.pt")
3641

3742
# Initialize model
3843
input_dim = train_embeddings.shape[1]
3944
model = SimpleNN(input_dim=input_dim).to(device)
4045
model.load_state_dict(torch.load(
41-
f"cache/{NAME}/round_{ROUND}/{NAME}_round{ROUND}.pth",
46+
f"../cache/{NAME}/round_{ROUND}/{NAME}_round{ROUND}.pth",
4247
map_location="cpu"
4348
))
4449
model.eval()

0 commit comments

Comments
 (0)