-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmutated_validation.py
More file actions
109 lines (90 loc) · 3.24 KB
/
mutated_validation.py
File metadata and controls
109 lines (90 loc) · 3.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""implementation of Mutation Validation Score"""
from enum import Enum
from typing import Union
from dataclasses import dataclass
import random
from sklearn.metrics import (
accuracy_score,
precision_score,
roc_auc_score,
recall_score
)
from imblearn.metrics import geometric_mean_score
import numpy as np
from mutate_labels import MutatedValidation
# Seed both global random number generators as soon as this module is
# imported, so any later draws from ``random`` or ``numpy.random`` in the same
# process are reproducible from run to run.
_GLOBAL_SEED = 42
random.seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)
class EvaluationMetricsType(str, Enum):
    """Names of the evaluation metrics that ``get_metric_score`` understands.

    Because the members also subclass ``str``, each one compares equal to its
    plain string value (``EvaluationMetricsType.ACCURACY == "accuracy"``), so
    either the member or the raw string can be used as a metric name.
    """

    ACCURACY = "accuracy"
    PRECISION = "precision"
    RECALL = "recall"
    ROC_AUC = "roc_auc"
    # Geometric mean; scored with imblearn's ``geometric_mean_score``.
    G_MEAN = "g_mean"
class TrainRun(int, Enum):
    """Integer tag distinguishing the two training runs.

    NOTE(review): by its name, ``ORIGINAL`` (0) marks the run trained on the
    unmodified labels and ``MUTATED`` (1) the run trained on mutated labels;
    the enum is not referenced in this module, so confirm against its users.
    """

    ORIGINAL = 0
    MUTATED = 1
@dataclass()
class EvaluationMetric:
    """One metric evaluation: the two label arrays compared and the score.

    Instances are produced by ``get_metric_score``; the inputs are kept next
    to the score so the result can be inspected afterwards.
    """

    # Labels produced by the model being evaluated.
    predicted_labels: np.ndarray
    # Reference labels the predictions were scored against. Despite the field
    # name, mutated labels are also passed through here by some callers.
    original_labels: np.ndarray
    # Value returned by the selected metric function.
    metric_score: Union[float, int]
# Dispatch table from metric name to scoring function. Each function is called
# as ``fn(reference_labels, predicted_labels)``, i.e. y_true first, which is
# the argument order scikit-learn and imbalanced-learn expect.
_METRIC_SCORERS = {
    EvaluationMetricsType.ACCURACY: accuracy_score,
    EvaluationMetricsType.PRECISION: precision_score,
    EvaluationMetricsType.RECALL: recall_score,
    # NOTE(review): roc_auc_score receives predicted labels rather than
    # probability/decision scores, so the AUC reflects a single threshold.
    # Confirm this is intended.
    EvaluationMetricsType.ROC_AUC: roc_auc_score,
    EvaluationMetricsType.G_MEAN: geometric_mean_score,
}


def get_metric_score(
    predicted_labels: np.ndarray, original_labels: np.ndarray, metric: str
) -> EvaluationMetric:
    """Score ``predicted_labels`` against ``original_labels`` with one metric.

    The metric functions are called with their library default arguments.
    For precision and recall that means scikit-learn's ``average='binary'``,
    so multiclass labels will make them raise.

    Args:
        predicted_labels: Labels predicted by the model.
        original_labels: Reference labels to score against. Callers also pass
            mutated labels here, so read "original" as "reference".
        metric: An ``EvaluationMetricsType`` member or its string value
            (``"accuracy"``, ``"precision"``, ``"recall"``, ``"roc_auc"`` or
            ``"g_mean"``).

    Returns:
        An ``EvaluationMetric`` holding both label arrays and the score.

    Raises:
        ValueError: If ``metric`` is not a supported metric name. Any error
            raised by the underlying metric function propagates unchanged.
    """
    try:
        # Normalise to the enum member: accepts the member itself or its value.
        metric_type = EvaluationMetricsType(metric)
    except ValueError as err:
        supported = ", ".join(repr(m.value) for m in EvaluationMetricsType)
        raise ValueError(
            f"Unsupported evaluation metric {metric!r}; expected one of {supported}"
        ) from err

    scorer = _METRIC_SCORERS[metric_type]
    score = scorer(original_labels, predicted_labels)
    return EvaluationMetric(
        predicted_labels=predicted_labels,
        original_labels=original_labels,
        metric_score=score,
    )
@dataclass
class MutatedValidationScore:
    """Mutation Validation (MV) score computed from two training runs.

    The score combines three evaluations, each done with
    ``get_metric_score`` and ``evaluation_metric``:

    * the mutated-label run's predictions against ``mutated_labels.labels``;
    * the original run's predictions against ``mutated_labels.labels``;
    * the mutated-label run's predictions against ``mutate``.

    NOTE(review): going by the property names, ``mutated_labels.labels`` is
    the original (un-mutated) label array and ``mutate`` the mutated one.
    ``MutatedValidation`` lives in ``mutate_labels`` and is not visible here,
    so confirm.
    """

    # Provides ``.labels`` (used as the original labels) and ``.perturbation_ratio``.
    mutated_labels: MutatedValidation
    # Label array used as the reference for the mutated-vs-mutated evaluation.
    mutate: np.ndarray
    # Predictions from the model trained on the original labels.
    original_training_predicted_labels: np.ndarray
    # Predictions from the model trained on the mutated labels.
    mutated_training_predicted_labels: np.ndarray
    # Metric name accepted by ``get_metric_score``.
    evaluation_metric: str

    def _score_against(self, predicted, reference):
        """Evaluate ``predicted`` against ``reference`` with ``evaluation_metric``."""
        return get_metric_score(
            predicted_labels=predicted,
            original_labels=reference,
            metric=self.evaluation_metric,
        )

    @property
    def mutated_training_metric_score_based_original_labels(self):
        """Mutated-label run's predictions scored against ``mutated_labels.labels``."""
        return self._score_against(
            self.mutated_training_predicted_labels, self.mutated_labels.labels
        )

    @property
    def original_training_metric_score_based_on_original_labels(self):
        """Original run's predictions scored against ``mutated_labels.labels``."""
        return self._score_against(
            self.original_training_predicted_labels, self.mutated_labels.labels
        )

    @property
    def mutated_training_metric_score_based_mutated_labels(self):
        """Mutated-label run's predictions scored against ``mutate``."""
        return self._score_against(
            self.mutated_training_predicted_labels, self.mutate
        )

    @property
    def get_mv_score(self):
        """Return the MV score.

        With ``eta = mutated_labels.perturbation_ratio``:

            (1 - 2*eta) * mutated_vs_original
            + (original_vs_original - mutated_vs_mutated)
            + eta

        Every term is recomputed on each access.
        """
        eta = self.mutated_labels.perturbation_ratio
        weighted_mutated_vs_original = (
            1 - 2 * eta
        ) * self.mutated_training_metric_score_based_original_labels.metric_score
        original_vs_original = (
            self.original_training_metric_score_based_on_original_labels.metric_score
        )
        mutated_vs_mutated = (
            self.mutated_training_metric_score_based_mutated_labels.metric_score
        )
        # Keep the original association: (weighted + gap) + eta.
        return weighted_mutated_vs_original + (original_vs_original - mutated_vs_mutated) + eta