Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New Classification Model for Fault Detection in 3W Dataset #64

Open
wants to merge 22 commits into
base: main
Choose a base branch
from

Conversation

yantavares
Copy link

This PR was made by the team Black Oil Pyrates in HACKTUDO.

Description:

This PR introduces a new ClassificationModel class to the repository, which enables the training and evaluation of custom machine learning models to classify faults in the 3W dataset. The ClassificationModel allows for flexible model integration, supporting different algorithms for the identification and classification of class faults.


Code Overview:

  • ClassificationModel class:
    • Attributes:
      • model: Allows the user to specify a custom classification model (e.g., Random Forest).
      • scaler: A StandardScaler instance used for feature normalization.
    • Methods:
      • load_and_prepare_data: Loads and preprocesses data from parquet files, handling timestamp cleaning and feature extraction.
      • train_model: Trains the provided machine learning model on the training dataset.
      • predict_single_file: Makes predictions on a single file, returning actual and predicted class labels along with the timestamps.
      • plot_predictions_with_timestamp: Visualizes the time-series predictions alongside actual class values using Plotly.
      • complete_analysis: Automates the entire process from data loading to prediction and visualization.
      • compare_models: Compares multiple models, training and testing them on different files, and evaluates their performance using accuracy and F1-score.

Example Usage:

from sklearn.ensemble import RandomForestClassifier
from classification import ClassificationModel

model = RandomForestClassifier(n_estimators=200, max_depth=12, random_state=42, class_weight='balanced')
dataset_path = './dataset/2'

# Initialize the classification model with Random Forest
classification_model = ClassificationModel(dataset_path=dataset_path, model=model)

# Get the list of parquet files
all_files = [os.path.join(dataset_path, f) for f in os.listdir(dataset_path) if f.endswith(".parquet")]

# Run the complete analysis
classification_model.complete_analysis(y_axis='T-TPT', files=all_files, train_size=0.8)

newplot (8)


Dependencies:

  • Pandas: For data loading, cleaning, and manipulation.
  • Scikit-learn: For implementing various machine learning models and performance metrics.
  • Plotly: For generating interactive visualizations of the time-series predictions.

This PR adds a flexible and scalable way to train machine learning models for fault classification in the 3W dataset, with support for custom models like Random Forest and interactive plotting capabilities.

Refactor calculate_tipping_point method to include is_validation and use_val_limit parameters
Refactor test method to include use_val_limit parameter
Refactor plot_SPE method to include is_validation and use_val_limit parameters
Refactor test_pipeline method to include use_val_limit parameter
@yantavares yantavares changed the base branch from dev to main October 25, 2024 17:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant