Skip to content

Commit e11cc7e

Browse files
authored
Merge pull request #9 from k4rimDev/development
💄 Added visualization module
2 parents d8764be + 48b1254 commit e11cc7e

10 files changed

Lines changed: 651 additions & 4 deletions

File tree

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,33 @@ y = pd.Series([0, 1, 0])
6969
X_train, X_test, y_train, y_test = preprocess_data(X, y, test_size=0.2, random_state=42)
7070
```
7171

72+
### Visualization
73+
Visualization functions can be used to generate plots of model performance:
74+
```py
75+
from random_forest_package.visualizer import ModelVisualizer
76+
77+
# Initialize the visualizer
78+
visualizer = ModelVisualizer(rf_model)
79+
80+
# Plot confusion matrix
81+
visualizer.plot_confusion_matrix(X_test, y_test)
82+
83+
# Plot ROC curve
84+
visualizer.plot_roc_curve(X_test, y_test)
85+
86+
# Plot precision-recall curve
87+
visualizer.plot_precision_recall_curve(X_test, y_test)
88+
```
89+
7290
## Custom Exceptions
7391
This package provides custom exceptions for better error handling:
7492

7593
* `ModelCreationError`: Raised when there is an error creating the random forest model.
7694
* `PreprocessingError`: Raised when there is an error during data preprocessing.
7795
* `TrainingError`: Raised when there is an error during model training.
7896
* `EvaluationError`: Raised when there is an error during model evaluation.
97+
* `VisualizationError`: Raised when there is an error during visualization.
98+
7999

80100
Example of handling a custom exception:
81101

@@ -117,6 +137,7 @@ random_forest_package/
117137
│ ├── trainer.py # Contains classes for training models
118138
│ ├── evaluator.py # Contains classes for evaluating models
119139
│ ├── utils.py # Utility functions or classes
140+
│ ├── visualizer.py # Visualization helper classes
120141
│ └── exceptions.py # Custom exceptions
121142
122143
├── tests/

poetry.lock

Lines changed: 441 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "random-forest-package"
3-
version = "0.1.3"
3+
version = "0.1.4"
44
description = "A Python package to facilitate random forest modeling."
55
authors = ["Karim Mirzaguliyev <kenanovkenan299@gmail.com>"]
66
readme = "README.md"
@@ -13,6 +13,8 @@ pandas = "^2.2.2"
1313
numpy = "^2.0.1"
1414
flake8 = "^7.1.1"
1515
lint = "^1.2.1"
16+
matplotlib = "^3.9.1.post1"
17+
seaborn = "^0.13.2"
1618

1719

1820
[tool.poetry.group.dev.dependencies]
Binary file not shown.
Binary file not shown.

random_forest_package/random_forest_package/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,9 @@ class EvaluationError(RandomForestPackageError):
2929
def __init__(self, message="Error during model evaluation"):
3030
self.message = message
3131
super().__init__(self.message)
32+
33+
class VisualizationError(RandomForestPackageError):
    """Raised when there is an error during visualization."""

    def __init__(self, message="Error during visualization"):
        # Keep the message accessible as an attribute for callers that inspect it.
        self.message = message
        super().__init__(message)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import matplotlib.pyplot as plt
2+
import seaborn as sns
3+
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve
4+
5+
from random_forest_package.exceptions import VisualizationError
6+
7+
8+
class ModelVisualizer:
    """Generate evaluation plots for a fitted binary classifier.

    Provides confusion-matrix, ROC-curve and precision-recall-curve plots.
    The wrapped model must expose ``predict`` and, for the two curve plots,
    ``predict_proba``. Any failure during prediction or plotting is re-raised
    as ``VisualizationError`` with the original exception chained.
    """

    def __init__(self, model):
        # A fitted estimator (e.g. RandomForestClassifier); not validated here.
        self.model = model

    def _label_axes(self, xlabel, ylabel, title):
        """Apply x/y axis labels and a title to the current matplotlib axes.

        (Renamed from the auto-generated `_extracted_from_plot_precision_recall_curve`,
        which was misleading: all three plot methods use it.)
        """
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)

    def plot_confusion_matrix(self, X, y, normalize=False):
        """Show a heatmap of the confusion matrix for predictions on X vs. y.

        When ``normalize`` is true, rows are normalized (sklearn's 'true' mode)
        so cells hold fractions and are formatted with two decimals; otherwise
        raw integer counts are shown.

        Raises:
            VisualizationError: if prediction or plotting fails.
        """
        try:
            y_pred = self.model.predict(X)
            cm = confusion_matrix(y, y_pred, normalize='true' if normalize else None)
            sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues')
            self._label_axes('Predicted', 'True', 'Confusion Matrix')
            plt.show()
        except Exception as e:
            raise VisualizationError(f"Error plotting confusion matrix: {e}") from e

    def plot_roc_curve(self, X, y):
        """Plot the ROC curve (with AUC in the legend) for the positive class.

        Uses column 1 of ``predict_proba``, so a model fitted on a single
        class will raise (surfaced as VisualizationError).

        Raises:
            VisualizationError: if prediction or plotting fails.
        """
        try:
            y_pred_proba = self.model.predict_proba(X)[:, 1]
            fpr, tpr, _ = roc_curve(y, y_pred_proba)
            roc_auc = auc(fpr, tpr)

            plt.figure()
            plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
            # Diagonal chance line for visual reference.
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            self._label_axes(
                'False Positive Rate',
                'True Positive Rate',
                'Receiver Operating Characteristic',
            )
            plt.legend(loc="lower right")
            plt.show()
        except Exception as e:
            raise VisualizationError(f"Error plotting ROC curve: {e}") from e

    def plot_precision_recall_curve(self, X, y):
        """Plot the precision-recall curve for the positive class.

        Raises:
            VisualizationError: if prediction or plotting fails.
        """
        try:
            y_pred_proba = self.model.predict_proba(X)[:, 1]
            precision, recall, _ = precision_recall_curve(y, y_pred_proba)

            plt.figure()
            plt.plot(recall, precision, color='b', lw=2)
            self._label_axes('Recall', 'Precision', 'Precision-Recall Curve')
            plt.show()
        except Exception as e:
            raise VisualizationError(f"Error plotting precision-recall curve: {e}") from e
Binary file not shown.
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import pytest
2+
import numpy as np
3+
from sklearn.datasets import make_classification
4+
from sklearn.model_selection import train_test_split
5+
from sklearn.ensemble import RandomForestClassifier
6+
from matplotlib import pyplot as plt
7+
8+
from random_forest_package.visualizer import ModelVisualizer
9+
from random_forest_package.exceptions import VisualizationError
10+
11+
12+
# Fixture to create a simple classification dataset
13+
@pytest.fixture(scope='module')
def classification_data():
    """Provide a small, reproducible binary-classification train/test split."""
    features, labels = make_classification(n_samples=100, n_features=20, n_classes=2, random_state=42)
    return train_test_split(features, labels, test_size=0.3, random_state=42)
17+
18+
19+
# Fixture to create a trained RandomForestClassifierModel
20+
@pytest.fixture(scope='module')
def trained_classifier(classification_data):
    """Fit a RandomForestClassifier and return it with the held-out test split."""
    X_train, X_test, y_train, y_test = classification_data
    forest = RandomForestClassifier(random_state=42)
    forest.fit(X_train, y_train)
    return forest, X_test, y_test
26+
27+
28+
# Tests for plot_confusion_matrix
29+
def test_plot_confusion_matrix_normal(trained_classifier):
    """plot_confusion_matrix should run without raising on valid test data."""
    model, X_test, y_test = trained_classifier
    visualizer = ModelVisualizer(model)

    # Any unexpected exception propagates and fails the test on its own;
    # the old try/except pytest.fail wrapper only obscured the traceback.
    visualizer.plot_confusion_matrix(X_test, y_test)
    plt.close()
38+
39+
40+
def test_plot_confusion_matrix_with_normalization(trained_classifier):
    """plot_confusion_matrix(normalize=True) should run without raising."""
    model, X_test, y_test = trained_classifier
    visualizer = ModelVisualizer(model)

    # Any unexpected exception propagates and fails the test on its own;
    # the old try/except pytest.fail wrapper only obscured the traceback.
    visualizer.plot_confusion_matrix(X_test, y_test, normalize=True)
    plt.close()
49+
50+
51+
def test_plot_confusion_matrix_with_invalid_input(trained_classifier):
    """Passing None inputs must surface as a VisualizationError."""
    model = trained_classifier[0]
    viz = ModelVisualizer(model)

    with pytest.raises(VisualizationError):
        viz.plot_confusion_matrix(None, None)
57+
58+
59+
# Tests for plot_roc_curve
60+
def test_plot_roc_curve_normal(trained_classifier):
    """plot_roc_curve should run without raising on valid test data."""
    model, X_test, y_test = trained_classifier
    visualizer = ModelVisualizer(model)

    # Any unexpected exception propagates and fails the test on its own;
    # the old try/except pytest.fail wrapper only obscured the traceback.
    visualizer.plot_roc_curve(X_test, y_test)
    plt.close()
69+
70+
71+
def test_plot_roc_curve_with_invalid_input(trained_classifier):
    """Passing None inputs must surface as a VisualizationError."""
    model = trained_classifier[0]
    viz = ModelVisualizer(model)

    with pytest.raises(VisualizationError):
        viz.plot_roc_curve(None, None)
77+
78+
79+
# Tests for plot_precision_recall_curve
80+
def test_plot_precision_recall_curve_normal(trained_classifier):
    """plot_precision_recall_curve should run without raising on valid test data."""
    model, X_test, y_test = trained_classifier
    visualizer = ModelVisualizer(model)

    # Any unexpected exception propagates and fails the test on its own;
    # the old try/except pytest.fail wrapper only obscured the traceback.
    visualizer.plot_precision_recall_curve(X_test, y_test)
    plt.close()
89+
90+
91+
def test_plot_precision_recall_curve_with_invalid_input(trained_classifier):
    """Passing None inputs must surface as a VisualizationError."""
    model = trained_classifier[0]
    viz = ModelVisualizer(model)

    with pytest.raises(VisualizationError):
        viz.plot_precision_recall_curve(None, None)
97+
98+
99+
def test_plot_precision_recall_curve_with_single_class(classification_data):
    """A model fitted on a single class may either plot or raise VisualizationError.

    Both outcomes are acceptable; anything else is a failure.
    """
    X_train, X_test, y_train, y_test = classification_data
    degenerate_labels = np.zeros_like(y_train)

    single_class_model = RandomForestClassifier(random_state=42)
    single_class_model.fit(X_train, degenerate_labels)

    viz = ModelVisualizer(single_class_model)

    try:
        viz.plot_precision_recall_curve(X_test, y_test)
        plt.close()
    except VisualizationError:
        # Expected: predict_proba has no positive-class column to index.
        pass
    except Exception as e:
        pytest.fail(f"Unexpected error: {e}")

setup.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33

44
setup(
55
name='random_forest_package',
6-
version='0.1',
6+
version='0.1.4',
77
packages=find_packages(),
88
install_requires=[
99
'scikit-learn',
1010
'numpy',
1111
'pandas',
1212
'flake8',
13-
'lint'
13+
'lint',
14+
'matplotlib',
15+
'seaborn'
1416
],
1517
author='Karim Mirzaguliyev',
1618
author_email='karimmirzaguliyev@gmail.com',

0 commit comments

Comments
 (0)