-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluation.py
More file actions
54 lines (44 loc) · 1.51 KB
/
evaluation.py
File metadata and controls
54 lines (44 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import warnings

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.models import load_model

from data_preparation import get_generators
# Evaluate the trained model on the test split: print the test accuracy,
# plot a confusion matrix and print a per-class classification report.

# Paths and generator settings (passed straight to get_generators / load_model).
DATA_DIR = r'C:\Users\defo\Documents\dataset'
BATCH_SIZE = 16
IMG_SIZE = (224, 224)
MODEL_PATH = r'C:\Users\defo\Documents\ViT_CancerDetection.h5'

# Display names indexed by integer class id.
# NOTE(review): assumes class id 0 is the non-cancerous class and 1 the
# cancerous one -- confirm against test_gen.class_indices.
CLASS_NAMES = ['Non-Cancerous', 'Cancerous']


def to_class_labels(preds, threshold=0.5):
    """Convert raw model outputs into integer class ids.

    Args:
        preds: Array-like of model outputs. A 2-D array with more than one
            column is treated as per-class scores and reduced with argmax;
            anything else (a single output column, or a 1-D array) is
            thresholded.
        threshold: Cut-off for the single-output case; scores strictly
            greater than it map to class 1, everything else to class 0.

    Returns:
        1-D numpy integer array with one class id per sample.
    """
    preds = np.asarray(preds)
    if preds.ndim == 2 and preds.shape[1] > 1:
        # One score per class: pick the highest-scoring class.
        return preds.argmax(axis=1)
    # Single score per sample: (N, 1) or (N,) -> flat 0/1 labels.
    return (preds > threshold).astype(int).ravel()


def plot_confusion_matrix(cm, class_names):
    """Show *cm* as an annotated heat map.

    Args:
        cm: Square confusion matrix; rows are true labels, columns are
            predicted labels (sklearn's convention).
        class_names: Tick labels, one per row/column of *cm*.
    """
    plt.figure(figsize=(6, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    # Switch cell text to white on dark cells so the counts stay legible.
    thresh = cm.max() / 2.0
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, cm[i, j], horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    # Lay out only after every label exists, otherwise they can be clipped.
    plt.tight_layout()
    plt.show()


def main():
    """Load the test generator and saved model, then report test metrics."""
    _, _, test_gen = get_generators(DATA_DIR, BATCH_SIZE, IMG_SIZE)
    model = load_model(MODEL_PATH)

    # Predictions are compared position-by-position with test_gen.classes,
    # which is only valid when the generator yields samples in file order.
    if getattr(test_gen, 'shuffle', False):
        warnings.warn(
            'test_gen shuffles its samples, so predictions will not line up '
            'with test_gen.classes; create the test generator with '
            'shuffle=False for a meaningful confusion matrix and report.',
            RuntimeWarning,
        )

    print("Evaluating model on the test set...")
    # NOTE(review): assumes the model was compiled with exactly one metric
    # (accuracy), so evaluate() returns [loss, accuracy].
    _, test_acc = model.evaluate(test_gen)
    print(f"Test Accuracy: {test_acc:.2f}")

    y_pred = to_class_labels(model.predict(test_gen))
    true_labels = test_gen.classes

    # Fix the label set so the matrix/report keep one row per class even if
    # some class never occurs in the truth or the predictions.
    label_ids = list(range(len(CLASS_NAMES)))
    cm = confusion_matrix(true_labels, y_pred, labels=label_ids)
    plot_confusion_matrix(cm, CLASS_NAMES)

    report = classification_report(
        true_labels, y_pred, labels=label_ids, target_names=CLASS_NAMES
    )
    print("Classification Report:")
    print(report)


if __name__ == '__main__':
    main()