Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gieljnssns committed Apr 17, 2024
1 parent 6095a2c commit 8fd7fbf
Show file tree
Hide file tree
Showing 3 changed files with 488 additions and 78 deletions.
109 changes: 77 additions & 32 deletions src/emhass/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,17 +240,23 @@ def set_input_data_dict(
return False
df_input_data = rh.df_final.copy()

elif set_type == "regressor-model-fit":
elif set_type == "regressor-model-fit" or set_type == "regressor-model-predict":

df_input_data_dayahead = None
df_input_data, df_input_data_dayahead = None, None
P_PV_forecast, P_load_forecast = None, None
params = json.loads(params)
days_list = None
csv_file = params["passed_data"]["csv_file"]
features = params["passed_data"]["features"]
target = params["passed_data"]["target"]
timestamp = params["passed_data"]["timestamp"]
filename_path = pathlib.Path(base_path) / csv_file
if get_data_from_file:
base_path = base_path + "/data"
filename_path = pathlib.Path(base_path) / csv_file

else:
filename_path = pathlib.Path(base_path) / csv_file

if filename_path.is_file():
df_input_data = pd.read_csv(filename_path, parse_dates=True)

Expand All @@ -266,21 +272,16 @@ def set_input_data_dict(
if not set(required_columns).issubset(df_input_data.columns):
logger.error("The cvs file does not contain the required columns.")
raise ValueError(
f"CSV file should contain the following columns: {', '.join(required_columns)}"
f"CSV file should contain the following columns: {', '.join(required_columns)}",
)
elif set_type == "regressor-model-predict":
df_input_data, df_input_data_dayahead = None, None
P_PV_forecast, P_load_forecast = None, None
days_list = None
params = json.loads(params)

elif set_type == "publish-data":
df_input_data, df_input_data_dayahead = None, None
P_PV_forecast, P_load_forecast = None, None
days_list = None
else:
logger.error(
"The passed action argument and hence the set_type parameter for setup is not valid"
"The passed action argument and hence the set_type parameter for setup is not valid",
)
df_input_data, df_input_data_dayahead = None, None
P_PV_forecast, P_load_forecast = None, None
Expand Down Expand Up @@ -541,7 +542,7 @@ def forecast_model_predict(
mlf = pickle.load(inp)
else:
logger.error(
"The ML forecaster file was not found, please run a model fit method before this predict method"
"The ML forecaster file was not found, please run a model fit method before this predict method",
)
return
# Make predictions
Expand Down Expand Up @@ -629,7 +630,7 @@ def forecast_model_tune(
mlf = pickle.load(inp)
else:
logger.error(
"The ML forecaster file was not found, please run a model fit method before this tune method"
"The ML forecaster file was not found, please run a model fit method before this tune method",
)
return None, None
# Tune the model
Expand All @@ -643,7 +644,9 @@ def forecast_model_tune(


def regressor_model_fit(
input_data_dict: dict, logger: logging.Logger, debug: Optional[bool] = False
input_data_dict: dict,
logger: logging.Logger,
debug: Optional[bool] = False,
) -> None:
"""Perform a forecast model fit from training data retrieved from Home Assistant.
Expand All @@ -662,9 +665,16 @@ def regressor_model_fit(
timestamp = input_data_dict["params"]["passed_data"]["timestamp"]
date_features = input_data_dict["params"]["passed_data"]["date_features"]
root = input_data_dict["root"]

# The MLRegressor object
mlr = MLRegressor(
data, model_type, regression_model, features, target, timestamp, logger
data,
model_type,
regression_model,
features,
target,
timestamp,
logger,
)
# Fit the ML model
mlr.fit(date_features=date_features)
Expand All @@ -673,10 +683,14 @@ def regressor_model_fit(
filename = model_type + "_mlr.pkl"
with open(pathlib.Path(root) / filename, "wb") as outp:
pickle.dump(mlr, outp, pickle.HIGHEST_PROTOCOL)
return mlr


def regressor_model_predict(
input_data_dict: dict, logger: logging.Logger, debug: Optional[bool] = False
input_data_dict: dict,
logger: logging.Logger,
debug: Optional[bool] = False,
mlr: Optional[MLRegressor] = None,
) -> None:
"""Perform a prediction from csv file.
Expand All @@ -697,7 +711,7 @@ def regressor_model_predict(
mlr = pickle.load(inp)
else:
logger.error(
"The ML forecaster file was not found, please run a model fit method before this predict method"
"The ML forecaster file was not found, please run a model fit method before this predict method",
)
return
new_values = input_data_dict["params"]["passed_data"]["new_values"]
Expand All @@ -715,14 +729,16 @@ def regressor_model_predict(
]
# Publish prediction
idx = 0
input_data_dict["rh"].post_data(
prediction,
idx,
mlr_predict_entity_id,
mlr_predict_unit_of_measurement,
mlr_predict_friendly_name,
type_var="mlregressor",
)
if not debug:
input_data_dict["rh"].post_data(
prediction,
idx,
mlr_predict_entity_id,
mlr_predict_unit_of_measurement,
mlr_predict_friendly_name,
type_var="mlregressor",
)
return prediction


def publish_data(
Expand Down Expand Up @@ -813,7 +829,7 @@ def publish_data(
if "P_deferrable{}".format(k) not in opt_res_latest.columns:
logger.error(
"P_deferrable{}".format(k)
+ " was not found in results DataFrame. Optimization task may need to be relaunched or it did not converge to a solution."
+ " was not found in results DataFrame. Optimization task may need to be relaunched or it did not converge to a solution.",
)
else:
input_data_dict["rh"].post_data(
Expand All @@ -830,7 +846,7 @@ def publish_data(
if input_data_dict["opt"].optim_conf["set_use_battery"]:
if "P_batt" not in opt_res_latest.columns:
logger.error(
"P_batt was not found in results DataFrame. Optimization task may need to be relaunched or it did not converge to a solution."
"P_batt was not found in results DataFrame. Optimization task may need to be relaunched or it did not converge to a solution.",
)
else:
custom_batt_forecast_id = params["passed_data"]["custom_batt_forecast_id"]
Expand Down Expand Up @@ -886,7 +902,7 @@ def publish_data(
if "optim_status" not in opt_res_latest:
opt_res_latest["optim_status"] = "Optimal"
logger.warning(
"no optim_status in opt_res_latest, run an optimization task first"
"no optim_status in opt_res_latest, run an optimization task first",
)
input_data_dict["rh"].post_data(
opt_res_latest["optim_status"],
Expand Down Expand Up @@ -957,7 +973,9 @@ def main():
naive-mpc-optim, publish-data, forecast-model-fit, forecast-model-predict, forecast-model-tune",
)
parser.add_argument(
"--config", type=str, help="Define path to the config.yaml file"
"--config",
type=str,
help="Define path to the config.yaml file",
)
parser.add_argument(
"--costfun",
Expand All @@ -984,7 +1002,10 @@ def main():
help="Pass runtime optimization parameters as dictionnary",
)
parser.add_argument(
"--debug", type=strtobool, default="False", help="Use True for testing purposes"
"--debug",
type=strtobool,
default="False",
help="Use True for testing purposes",
)
args = parser.parse_args()
# The path to the configuration files
Expand All @@ -995,12 +1016,14 @@ def main():
# Additionnal argument
try:
parser.add_argument(
"--version", action="version", version="%(prog)s " + version("emhass")
"--version",
action="version",
version="%(prog)s " + version("emhass"),
)
args = parser.parse_args()
except Exception:
logger.info(
"Version not found for emhass package. Or importlib exited with PackageNotFoundError."
"Version not found for emhass package. Or importlib exited with PackageNotFoundError.",
)
# Setup parameters
input_data_dict = set_input_data_dict(
Expand Down Expand Up @@ -1040,7 +1063,25 @@ def main():
else:
mlf = None
df_pred_optim, mlf = forecast_model_tune(
input_data_dict, logger, debug=args.debug, mlf=mlf
input_data_dict,
logger,
debug=args.debug,
mlf=mlf,
)
opt_res = None
elif args.action == "regressor-model-fit":
mlr = regressor_model_fit(input_data_dict, logger, debug=args.debug)
opt_res = None
elif args.action == "regressor-model-predict":
if args.debug:
mlr = regressor_model_fit(input_data_dict, logger, debug=args.debug)
else:
mlr = None
prediction = regressor_model_predict(
input_data_dict,
logger,
debug=args.debug,
mlr=mlr,
)
opt_res = None
elif args.action == "publish-data":
Expand All @@ -1063,6 +1104,10 @@ def main():
return df_fit_pred, df_fit_pred_backtest, mlf
elif args.action == "forecast-model-predict":
return df_pred
elif args.action == "regressor-model-fit":
return mlr
elif args.action == "regressor-model-predict":
return prediction
elif args.action == "forecast-model-tune":
return df_pred_optim, mlf

Expand Down
Loading

0 comments on commit 8fd7fbf

Please sign in to comment.