Skip to content

Commit

Permalink
Merge pull request #9 from k4rimDev/development
Browse files Browse the repository at this point in the history
💄 Added visualization module
  • Loading branch information
k4rimDev authored Aug 9, 2024
2 parents d8764be + 48b1254 commit e11cc7e
Show file tree
Hide file tree
Showing 10 changed files with 651 additions and 4 deletions.
21 changes: 21 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,33 @@ y = pd.Series([0, 1, 0])
X_train, X_test, y_train, y_test = preprocess_data(X, y, test_size=0.2, random_state=42)
```

### Visualization
Visualization functions can be used to generate plots of model performance:
```py
from random_forest_package.visualizer import ModelVisualizer

# Initialize the visualizer
visualizer = ModelVisualizer(rf_model)

# Plot confusion matrix
visualizer.plot_confusion_matrix(X_test, y_test)

# Plot ROC curve
visualizer.plot_roc_curve(X_test, y_test)

# Plot precision-recall curve
visualizer.plot_precision_recall_curve(X_test, y_test)
```

## Custom Exceptions
This package provides custom exceptions for better error handling:

* `ModelCreationError`: Raised when there is an error creating the random forest model.
* `PreprocessingError`: Raised when there is an error during data preprocessing.
* `TrainingError`: Raised when there is an error during model training.
* `EvaluationError`: Raised when there is an error during model evaluation.
* `VisualizationError`: Raised when there is an error during visualization.


Example of handling a custom exception:

Expand Down Expand Up @@ -117,6 +137,7 @@ random_forest_package/
│ ├── trainer.py # Contains classes for training models
│ ├── evaluator.py # Contains classes for evaluating models
│ ├── utils.py # Utility functions or classes
│ ├── visualizer.py # Utility visualize cases
│ └── exceptions.py # Custom exceptions
├── tests/
Expand Down
442 changes: 441 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "random-forest-package"
version = "0.1.3"
version = "0.1.4"
description = "A Python package to facilitate random forest modeling."
authors = ["Karim Mirzaguliyev <[email protected]>"]
readme = "README.md"
Expand All @@ -13,6 +13,8 @@ pandas = "^2.2.2"
numpy = "^2.0.1"
flake8 = "^7.1.1"
lint = "^1.2.1"
matplotlib = "^3.9.1.post1"
seaborn = "^0.13.2"


[tool.poetry.group.dev.dependencies]
Expand Down
Binary file not shown.
Binary file not shown.
6 changes: 6 additions & 0 deletions random_forest_package/random_forest_package/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,9 @@ class EvaluationError(RandomForestPackageError):
def __init__(self, message="Error during model evaluation"):
self.message = message
super().__init__(self.message)

class VisualizationError(RandomForestPackageError):
"""Raised when there is an error during visualization."""
def __init__(self, message="Error during visualization"):
self.message = message
super().__init__(self.message)
62 changes: 62 additions & 0 deletions random_forest_package/random_forest_package/visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve

from random_forest_package.exceptions import VisualizationError


class ModelVisualizer:
def __init__(self, model):
self.model = model

def _extracted_from_plot_precision_recall_curve(self, arg0, arg1, arg2):
plt.xlabel(arg0)
plt.ylabel(arg1)
plt.title(arg2)

def plot_confusion_matrix(self, X, y, normalize=False):
try:
y_pred = self.model.predict(X)
cm = confusion_matrix(y, y_pred, normalize='true' if normalize else None)
sns.heatmap(cm, annot=True, fmt='.2f' if normalize else 'd', cmap='Blues')
self._extracted_from_plot_precision_recall_curve(
'Predicted', 'True', 'Confusion Matrix'
)
plt.show()
except Exception as e:
raise VisualizationError(f"Error plotting confusion matrix: {e}") from e

def plot_roc_curve(self, X, y):
try:
y_pred_proba = self.model.predict_proba(X)[:, 1]
fpr, tpr, _ = roc_curve(y, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
self._extracted_from_plot_precision_recall_curve(
'False Positive Rate',
'True Positive Rate',
'Receiver Operating Characteristic',
)
plt.legend(loc="lower right")
plt.show()
except Exception as e:
raise VisualizationError(f"Error plotting ROC curve: {e}") from e

def plot_precision_recall_curve(self, X, y):
try:
y_pred_proba = self.model.predict_proba(X)[:, 1]
precision, recall, _ = precision_recall_curve(y, y_pred_proba)

plt.figure()
plt.plot(recall, precision, color='b', lw=2)
self._extracted_from_plot_precision_recall_curve(
'Recall', 'Precision', 'Precision-Recall Curve'
)
plt.show()
except Exception as e:
raise VisualizationError(f"Error plotting precision-recall curve: {e}") from e
Binary file not shown.
114 changes: 114 additions & 0 deletions random_forest_package/tests/test_visualizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import pytest
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from matplotlib import pyplot as plt

from random_forest_package.visualizer import ModelVisualizer
from random_forest_package.exceptions import VisualizationError


# Fixture to create a simple classification dataset
@pytest.fixture(scope='module')
def classification_data():
X, y = make_classification(n_samples=100, n_features=20, n_classes=2, random_state=42)
return train_test_split(X, y, test_size=0.3, random_state=42)


# Fixture to create a trained RandomForestClassifierModel
@pytest.fixture(scope='module')
def trained_classifier(classification_data):
X_train, X_test, y_train, y_test = classification_data
model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train)
return model, X_test, y_test


# Tests for plot_confusion_matrix
def test_plot_confusion_matrix_normal(trained_classifier):
model, X_test, y_test = trained_classifier
visualizer = ModelVisualizer(model)

try:
visualizer.plot_confusion_matrix(X_test, y_test)
plt.close()
except Exception as e:
pytest.fail(f"Unexpected error: {e}")


def test_plot_confusion_matrix_with_normalization(trained_classifier):
model, X_test, y_test = trained_classifier
visualizer = ModelVisualizer(model)

try:
visualizer.plot_confusion_matrix(X_test, y_test, normalize=True)
plt.close()
except Exception as e:
pytest.fail(f"Unexpected error: {e}")


def test_plot_confusion_matrix_with_invalid_input(trained_classifier):
model, _, _ = trained_classifier
visualizer = ModelVisualizer(model)

with pytest.raises(VisualizationError):
visualizer.plot_confusion_matrix(None, None)


# Tests for plot_roc_curve
def test_plot_roc_curve_normal(trained_classifier):
model, X_test, y_test = trained_classifier
visualizer = ModelVisualizer(model)

try:
visualizer.plot_roc_curve(X_test, y_test)
plt.close()
except Exception as e:
pytest.fail(f"Unexpected error: {e}")


def test_plot_roc_curve_with_invalid_input(trained_classifier):
model, _, _ = trained_classifier
visualizer = ModelVisualizer(model)

with pytest.raises(VisualizationError):
visualizer.plot_roc_curve(None, None)


# Tests for plot_precision_recall_curve
def test_plot_precision_recall_curve_normal(trained_classifier):
model, X_test, y_test = trained_classifier
visualizer = ModelVisualizer(model)

try:
visualizer.plot_precision_recall_curve(X_test, y_test)
plt.close()
except Exception as e:
pytest.fail(f"Unexpected error: {e}")


def test_plot_precision_recall_curve_with_invalid_input(trained_classifier):
model, _, _ = trained_classifier
visualizer = ModelVisualizer(model)

with pytest.raises(VisualizationError):
visualizer.plot_precision_recall_curve(None, None)


def test_plot_precision_recall_curve_with_single_class(classification_data):
X_train, X_test, y_train, y_test = classification_data
y_train_single_class = np.zeros_like(y_train)

model = RandomForestClassifier(random_state=42)
model.fit(X_train, y_train_single_class)

visualizer = ModelVisualizer(model)

try:
visualizer.plot_precision_recall_curve(X_test, y_test)
plt.close()
except VisualizationError:
pass # Expected outcome
except Exception as e:
pytest.fail(f"Unexpected error: {e}")
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

setup(
name='random_forest_package',
version='0.1',
version='0.1.4',
packages=find_packages(),
install_requires=[
'scikit-learn',
'numpy',
'pandas',
'flake8',
'lint'
'lint',
'matplotlib',
'seaborn'
],
author='Karim Mirzaguliyev',
author_email='[email protected]',
Expand Down

0 comments on commit e11cc7e

Please sign in to comment.