Skip to content

Commit

Permalink
resolve some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gieljnssns committed Jan 29, 2024
1 parent 1186bed commit a5a3361
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions src/emhass/csv_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import pathlib
import time
from typing import Tuple
import warnings

import pandas as pd
import numpy as np

Expand All @@ -13,26 +15,26 @@
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

warnings.filterwarnings("ignore", category=DeprecationWarning)

class CsvPredictor:
r"""
A forecaster class using machine learning models.
This class uses the `skforecast` module and the machine learning models are from `scikit-learn`.
This class uses machine learning models from `scikit-learn` (imported as `sklearn`).
It exposes one main method:
- `predict`: to obtain a forecast from a csv file.
"""
def __init__(self, csv_file: str, independent_variables: list, dependent_variable: str, sklearn_model: str, new_values:list, root: str,
logger: logging.Logger) -> None:
logger: logging.Logger) -> None:
r"""Define constructor for the forecast class.
:param csv_file: The name of the csv file to retrieve data from. \
Example: `prediction.csv`.
Example: `input_train_data.csv`.
:type csv_file: str
:param independent_variables: A list of independent variables. \
Example: [`solar`, `degree_days`].
Expand Down Expand Up @@ -60,7 +62,6 @@ def __init__(self, csv_file: str, independent_variables: list, dependent_variabl
self.logger = logger
self.is_tuned = False


def load_data(self) -> pd.DataFrame:
"""Load the data."""
filename_path = pathlib.Path(self.root) / self.csv_file
Expand All @@ -69,18 +70,16 @@ def load_data(self) -> pd.DataFrame:
data = pd.read_csv(inp)
else:
self.logger.error("The cvs file was not found.")
raise ValueError(
f"The CSV file "+ self.csv_file +" was not found."
)
raise ValueError("The CSV file " + self.csv_file + " was not found.")

required_columns = self.independent_variables

if not set(required_columns).issubset(data.columns):
raise ValueError(
f"CSV file should contain the following columns: {', '.join(required_columns)}"
)
return data

def prepare_data(self, data) -> Tuple[np.ndarray, np.ndarray]:
"""
Prepare the data.
Expand All @@ -94,10 +93,10 @@ def prepare_data(self, data) -> Tuple[np.ndarray, np.ndarray]:
X = data[self.independent_variables].values
y = data[self.dependent_variable].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

return X_train, y_train


def predict(self) -> np.ndarray:
r"""The predict method to generate a forecast from a csv file.
Expand All @@ -109,7 +108,7 @@ def predict(self) -> np.ndarray:
data = self.load_data()
if data is not None:
X, y = self.prepare_data(data)

if self.sklearn_model == 'LinearRegression':
base_model = LinearRegression()
elif self.sklearn_model == 'ElasticNet':
Expand All @@ -127,9 +126,5 @@ def predict(self) -> np.ndarray:
self.logger.info(f"Elapsed time for model fit: {time.time() - start_time}")
new_values = np.array([self.new_values])
prediction = self.forecaster.predict(new_values)

return prediction




0 comments on commit a5a3361

Please sign in to comment.