Skip to content

Commit

Permalink
Passing prediction column from handler to the model .py files.
Browse files Browse the repository at this point in the history
With this we won't have to rely on the last column always being
the prediction column.
  • Loading branch information
Jineet Desai committed Oct 17, 2023
1 parent 80608ba commit 57990df
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
8 changes: 8 additions & 0 deletions evadb/executor/create_function_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def handle_sklearn_function(self):
self.node.metadata.append(
FunctionMetadataCatalogEntry("model_path", model_path)
)
# Pass the prediction column name to sklearn.py
self.node.metadata.append(
FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
)

impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
io_list = self._resolve_function_io(None)
Expand Down Expand Up @@ -205,6 +209,10 @@ def handle_xgboost_function(self):
self.node.metadata.append(
FunctionMetadataCatalogEntry("model_path", model_path)
)
# Pass the prediction column to xgboost.py.
self.node.metadata.append(
FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
)

impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
io_list = self._resolve_function_io(None)
Expand Down
16 changes: 8 additions & 8 deletions evadb/functions/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@ class GenericSklearnModel(AbstractFunction):
def name(self) -> str:
return "GenericSklearnModel"

def setup(self, model_path: str, **kwargs):
def setup(self, model_path: str, predict_col: str, **kwargs):
try_to_import_sklearn()

self.model = pickle.load(open(model_path, "rb"))
self.predict_col = predict_col

def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
# The last column is the predictor variable column. Hence we do not
# pass that column in the predict method for sklearn.
predictions = self.model.predict(frames.iloc[:, :-1])
# Do not pass the prediction column in the predict method for sklearn.
frames.drop([self.predict_col], axis=1, inplace=True)
predictions = self.model.predict(frames)
predict_df = pd.DataFrame(predictions)
# We need to rename the column of the output dataframe. For this we
# shall rename it to the column name same as that of the last column of
# frames. This is because the last column of frames corresponds to the
# variable we want to predict.
predict_df.rename(columns={0: frames.columns[-1]}, inplace=True)
# shall rename it to the column name same as that of the predict column
# passed in the training frames in EVA query.
predict_df.rename(columns={0: self.predict_col}, inplace=True)
return predict_df

def to_device(self, device: str):
Expand Down
17 changes: 9 additions & 8 deletions evadb/functions/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,22 @@ class GenericXGBoostModel(AbstractFunction):
def name(self) -> str:
return "GenericXGBoostModel"

def setup(self, model_path: str, **kwargs):
def setup(self, model_path: str, predict_col: str, **kwargs):
try_to_import_xgboost()

self.model = pickle.load(open(model_path, "rb"))
self.predict_col = predict_col

def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
# Last column is the value to predict, hence don't pass that to the
# predict method.
predictions = self.model.predict(frames.iloc[:, :-1])
# We do not pass the prediction column to the predict method of XGBoost
# AutoML.
frames.drop([self.predict_col], axis=1, inplace=True)
predictions = self.model.predict(frames)
predict_df = pd.DataFrame(predictions)
# We need to rename the column of the output dataframe. For this we
# shall rename it to the column name same as that of the last column of
# frames. This is because the last column of frames corresponds to the
# variable we want to predict.
predict_df.rename(columns={0: frames.columns[-1]}, inplace=True)
# shall rename it to the column name same as that of the predict column
# passed to EVA query.
predict_df.rename(columns={0: self.predict_col}, inplace=True)
return predict_df

def to_device(self, device: str):
Expand Down

0 comments on commit 57990df

Please sign in to comment.