diff --git a/docs/_toc.yml b/docs/_toc.yml
index a8639dec3..3b3eeda5e 100644
--- a/docs/_toc.yml
+++ b/docs/_toc.yml
@@ -88,6 +88,8 @@ parts:
         title: Model Training with Ludwig
       - file: source/reference/ai/model-train-sklearn
         title: Model Training with Sklearn
+      - file: source/reference/ai/model-train-xgboost
+        title: Model Training with XGBoost
       - file: source/reference/ai/model-forecasting
         title: Time Series Forecasting
       - file: source/reference/ai/hf
diff --git a/docs/source/reference/ai/model-train-xgboost.rst b/docs/source/reference/ai/model-train-xgboost.rst
new file mode 100644
index 000000000..b53c87d48
--- /dev/null
+++ b/docs/source/reference/ai/model-train-xgboost.rst
@@ -0,0 +1,26 @@
+.. _xgboost:
+
+Model Training with XGBoost
+============================
+
+1. Installation
+---------------
+
+To use the `Flaml XGBoost AutoML framework `_, you need to install the extra Flaml dependency in your EvaDB virtual environment.
+
+.. code-block:: bash
+
+   pip install "flaml[automl]"
+
+2. Example Query
+----------------
+
+.. code-block:: sql
+
+   CREATE FUNCTION IF NOT EXISTS PredictRent FROM
+   ( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
+   TYPE XGBoost
+   PREDICT 'rental_price';
+
+In the above query, you are creating a new customized function by training a model on the ``HomeRentals`` table using the ``Flaml XGBoost`` framework.
+The ``rental_price`` column will be the target column for prediction, while the remaining columns from the ``SELECT`` query are the inputs.
diff --git a/evadb/binder/statement_binder.py b/evadb/binder/statement_binder.py
index f9087b5be..f1e949941 100644
--- a/evadb/binder/statement_binder.py
+++ b/evadb/binder/statement_binder.py
@@ -102,7 +102,9 @@ def _bind_create_function_statement(self, node: CreateFunctionStatement):
                     outputs.append(column)
                 else:
                     inputs.append(column)
-        elif string_comparison_case_insensitive(node.function_type, "sklearn"):
+        elif string_comparison_case_insensitive(
+            node.function_type, "sklearn"
+        ) or string_comparison_case_insensitive(node.function_type, "XGBoost"):
             assert (
                 "predict" in arg_map
             ), f"Creating {node.function_type} functions expects 'predict' metadata."
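For context before the executor changes below: a `TYPE XGBoost` function boils down to a single FLAML `AutoML` fit restricted to the `xgboost` estimator, followed by a plain `predict`. The following standalone Python sketch (not part of the diff) reproduces that call; the synthetic HomeRentals-style data, the short time budget, and the printed output are illustrative assumptions, not EvaDB code.

import numpy as np
import pandas as pd
from flaml import AutoML

# Synthetic stand-in for the HomeRentals table used in the docs above.
rng = np.random.default_rng(0)
rooms = rng.integers(1, 6, size=200)
baths = rng.integers(1, 4, size=200)
days = rng.integers(1, 90, size=200)
df = pd.DataFrame(
    {
        "number_of_rooms": rooms,
        "number_of_bathrooms": baths,
        "days_on_market": days,
        # Noisy linear rent so the regressor has something to learn.
        "rental_price": 800 * rooms + 300 * baths - 5 * days + rng.normal(0, 50, 200),
    }
)

model = AutoML()
model.fit(
    dataframe=df,
    label="rental_price",        # the PREDICT column from the query
    time_budget=10,              # EvaDB defaults to DEFAULT_TRAIN_TIME_LIMIT (120s)
    metric="rmse",               # EvaDB defaults to DEFAULT_TRAIN_REGRESSION_METRIC
    estimator_list=["xgboost"],  # restrict the AutoML search to XGBoost
    task="regression",
)

# Inference mirrors the new GenericXGBoostModel.forward(): drop the predict column.
print(model.predict(df.drop(columns=["rental_price"]))[:5])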
diff --git a/evadb/configuration/constants.py b/evadb/configuration/constants.py
index 8a6f95b5c..395f898be 100644
--- a/evadb/configuration/constants.py
+++ b/evadb/configuration/constants.py
@@ -34,3 +34,4 @@
 DEFAULT_TRAIN_TIME_LIMIT = 120
 DEFAULT_DOCUMENT_CHUNK_SIZE = 4000
 DEFAULT_DOCUMENT_CHUNK_OVERLAP = 200
+DEFAULT_TRAIN_REGRESSION_METRIC = "rmse"
diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py
index 0b4ddbf7c..379157563 100644
--- a/evadb/executor/create_function_executor.py
+++ b/evadb/executor/create_function_executor.py
@@ -25,6 +25,7 @@
 from evadb.catalog.models.function_io_catalog import FunctionIOCatalogEntry
 from evadb.catalog.models.function_metadata_catalog import FunctionMetadataCatalogEntry
 from evadb.configuration.constants import (
+    DEFAULT_TRAIN_REGRESSION_METRIC,
     DEFAULT_TRAIN_TIME_LIMIT,
     EvaDB_INSTALLATION_DIR,
 )
@@ -44,6 +45,7 @@
     try_to_import_statsforecast,
     try_to_import_torch,
     try_to_import_ultralytics,
+    try_to_import_xgboost,
 )
 from evadb.utils.logging_manager import logger
 
@@ -152,6 +154,10 @@ def handle_sklearn_function(self):
         self.node.metadata.append(
             FunctionMetadataCatalogEntry("model_path", model_path)
         )
+        # Pass the prediction column name to sklearn.py
+        self.node.metadata.append(
+            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
+        )
 
         impl_path = Path(f"{self.function_dir}/sklearn.py").absolute().as_posix()
         io_list = self._resolve_function_io(None)
@@ -163,6 +169,61 @@
             self.node.metadata,
         )
 
+    def handle_xgboost_function(self):
+        """Handle XGBoost functions
+
+        We use the Flaml AutoML framework to train XGBoost models.
+        """
+        try_to_import_xgboost()
+
+        assert (
+            len(self.children) == 1
+        ), "Create XGBoost function expects 1 child, finds {}.".format(
+            len(self.children)
+        )
+
+        aggregated_batch_list = []
+        child = self.children[0]
+        for batch in child.exec():
+            aggregated_batch_list.append(batch)
+        aggregated_batch = Batch.concat(aggregated_batch_list, copy=False)
+        aggregated_batch.drop_column_alias()
+
+        arg_map = {arg.key: arg.value for arg in self.node.metadata}
+        from flaml import AutoML
+
+        model = AutoML()
+        settings = {
+            "time_budget": arg_map.get("time_limit", DEFAULT_TRAIN_TIME_LIMIT),
+            "metric": arg_map.get("metric", DEFAULT_TRAIN_REGRESSION_METRIC),
+            "estimator_list": ["xgboost"],
+            "task": "regression",
+        }
+        model.fit(
+            dataframe=aggregated_batch.frames, label=arg_map["predict"], **settings
+        )
+        model_path = os.path.join(
+            self.db.config.get_value("storage", "model_dir"), self.node.name
+        )
+        pickle.dump(model, open(model_path, "wb"))
+        self.node.metadata.append(
+            FunctionMetadataCatalogEntry("model_path", model_path)
+        )
+        # Pass the prediction column to xgboost.py.
+        self.node.metadata.append(
+            FunctionMetadataCatalogEntry("predict_col", arg_map["predict"])
+        )
+
+        impl_path = Path(f"{self.function_dir}/xgboost.py").absolute().as_posix()
+        io_list = self._resolve_function_io(None)
+        return (
+            self.node.name,
+            impl_path,
+            self.node.function_type,
+            io_list,
+            self.node.metadata,
+        )
+
     def handle_ultralytics_function(self):
         """Handle Ultralytics functions"""
         try_to_import_ultralytics()
@@ -516,6 +577,14 @@ def exec(self, *args, **kwargs):
                 io_list,
                 metadata,
             ) = self.handle_sklearn_function()
+        elif string_comparison_case_insensitive(self.node.function_type, "XGBoost"):
+            (
+                name,
+                impl_path,
+                function_type,
+                io_list,
+                metadata,
+            ) = self.handle_xgboost_function()
         elif string_comparison_case_insensitive(self.node.function_type, "Forecasting"):
             (
                 name,
diff --git a/evadb/functions/sklearn.py b/evadb/functions/sklearn.py
index ca3676f14..4ab2b0abf 100644
--- a/evadb/functions/sklearn.py
+++ b/evadb/functions/sklearn.py
@@ -25,21 +25,21 @@ class GenericSklearnModel(AbstractFunction):
     def name(self) -> str:
         return "GenericSklearnModel"
 
-    def setup(self, model_path: str, **kwargs):
+    def setup(self, model_path: str, predict_col: str, **kwargs):
         try_to_import_sklearn()
 
         self.model = pickle.load(open(model_path, "rb"))
+        self.predict_col = predict_col
 
     def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
-        # The last column is the predictor variable column. Hence we do not
-        # pass that column in the predict method for sklearn.
-        predictions = self.model.predict(frames.iloc[:, :-1])
+        # Do not pass the prediction column in the predict method for sklearn.
+        frames.drop([self.predict_col], axis=1, inplace=True)
+        predictions = self.model.predict(frames)
         predict_df = pd.DataFrame(predictions)
         # We need to rename the column of the output dataframe. For this we
-        # shall rename it to the column name same as that of the last column of
-        # frames. This is because the last column of frames corresponds to the
-        # variable we want to predict.
-        predict_df.rename(columns={0: frames.columns[-1]}, inplace=True)
+        # shall rename it to the same name as the predict column passed in
+        # the EvaDB query.
+        predict_df.rename(columns={0: self.predict_col}, inplace=True)
         return predict_df
 
     def to_device(self, device: str):
diff --git a/evadb/functions/xgboost.py b/evadb/functions/xgboost.py
new file mode 100644
index 000000000..063529411
--- /dev/null
+++ b/evadb/functions/xgboost.py
@@ -0,0 +1,48 @@
+# coding=utf-8
+# Copyright 2018-2023 EvaDB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pickle
+
+import pandas as pd
+
+from evadb.functions.abstract.abstract_function import AbstractFunction
+from evadb.utils.generic_utils import try_to_import_xgboost
+
+
+class GenericXGBoostModel(AbstractFunction):
+    @property
+    def name(self) -> str:
+        return "GenericXGBoostModel"
+
+    def setup(self, model_path: str, predict_col: str, **kwargs):
+        try_to_import_xgboost()
+
+        self.model = pickle.load(open(model_path, "rb"))
+        self.predict_col = predict_col
+
+    def forward(self, frames: pd.DataFrame) -> pd.DataFrame:
+        # We do not pass the prediction column to the predict method of XGBoost
+        # AutoML.
+        frames.drop([self.predict_col], axis=1, inplace=True)
+        predictions = self.model.predict(frames)
+        predict_df = pd.DataFrame(predictions)
+        # We need to rename the column of the output dataframe. For this we
+        # shall rename it to the same name as the predict column passed in
+        # the EvaDB query.
+        predict_df.rename(columns={0: self.predict_col}, inplace=True)
+        return predict_df
+
+    def to_device(self, device: str):
+        # TODO figure out how to control the GPU for xgboost models
+        return self
diff --git a/evadb/utils/generic_utils.py b/evadb/utils/generic_utils.py
index a444fb983..fb6bd9986 100644
--- a/evadb/utils/generic_utils.py
+++ b/evadb/utils/generic_utils.py
@@ -377,6 +377,25 @@ def is_sklearn_available() -> bool:
         return False
 
 
+def try_to_import_xgboost():
+    try:
+        import flaml  # noqa: F401
+        from flaml import AutoML  # noqa: F401
+    except ImportError:
+        raise ValueError(
+            """Could not import Flaml AutoML.
+                Please install it with `pip install "flaml[automl]"`."""
+        )
+
+
+def is_xgboost_available() -> bool:
+    try:
+        try_to_import_xgboost()
+        return True
+    except ValueError:  # noqa: E722
+        return False
+
+
 ##############################
 ## VISION
 ##############################
diff --git a/setup.py b/setup.py
index 9c488c939..a18796d84 100644
--- a/setup.py
+++ b/setup.py
@@ -120,6 +120,8 @@ def read(path, encoding="utf-8"):
 
 sklearn_libs = ["scikit-learn"]
 
+xgboost_libs = ["flaml[automl]"]
+
 forecasting_libs = [
     "statsforecast",  # MODEL TRAIN AND FINE TUNING
     "neuralforecast"  # MODEL TRAIN AND FINE TUNING
@@ -169,9 +171,10 @@ def read(path, encoding="utf-8"):
     "postgres": postgres_libs,
     "ludwig": ludwig_libs,
     "sklearn": sklearn_libs,
+    "xgboost": xgboost_libs,
     "forecasting": forecasting_libs,
     # everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11.
-    "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs,
+    "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs
 }
 
 setup(
diff --git a/test/integration_tests/long/test_model_train.py b/test/integration_tests/long/test_model_train.py
index 7424ba424..85e508f4d 100644
--- a/test/integration_tests/long/test_model_train.py
+++ b/test/integration_tests/long/test_model_train.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
-from test.markers import ludwig_skip_marker, sklearn_skip_marker
+from test.markers import ludwig_skip_marker, sklearn_skip_marker, xgboost_skip_marker
 from test.util import get_evadb_for_testing, shutdown_ray
 
 import pytest
@@ -95,6 +95,25 @@ def test_sklearn_regression(self):
         self.assertEqual(len(result.columns), 1)
         self.assertEqual(len(result), 10)
 
+    @xgboost_skip_marker
+    def test_xgboost_regression(self):
+        create_predict_function = """
+            CREATE FUNCTION IF NOT EXISTS PredictRent FROM
+            ( SELECT number_of_rooms, number_of_bathrooms, days_on_market, rental_price FROM HomeRentals )
+            TYPE XGBoost
+            PREDICT 'rental_price'
+            TIME_LIMIT 180
+            METRIC 'r2';
+        """
+        execute_query_fetch_all(self.evadb, create_predict_function)
+
+        predict_query = """
+            SELECT PredictRent(number_of_rooms, number_of_bathrooms, days_on_market, rental_price) FROM HomeRentals LIMIT 10;
+        """
+        result = execute_query_fetch_all(self.evadb, predict_query)
+        self.assertEqual(len(result.columns), 1)
+        self.assertEqual(len(result), 10)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/markers.py b/test/markers.py
index 7d98e5534..3e95c1cff 100644
--- a/test/markers.py
+++ b/test/markers.py
@@ -27,6 +27,7 @@
     is_qdrant_available,
     is_replicate_available,
     is_sklearn_available,
+    is_xgboost_available,
 )
 
 asyncio_skip_marker = pytest.mark.skipif(
@@ -89,6 +90,10 @@
     is_sklearn_available() is False, reason="Run only if sklearn is available"
 )
 
+xgboost_skip_marker = pytest.mark.skipif(
+    is_xgboost_available() is False, reason="Run only if xgboost is available"
+)
+
 chatgpt_skip_marker = pytest.mark.skip(
     reason="requires chatgpt",
 )
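One detail worth calling out from the test above: the query passes `rental_price` into `PredictRent(...)` even though it is the target, because `GenericXGBoostModel.forward()` drops the predict column before calling the model. A rough standalone sketch of that inference path follows (not part of the diff; the pickle path and the one-row frame are hypothetical).

import pickle

import pandas as pd

# Hypothetical location; EvaDB stores the model under its configured model_dir.
model = pickle.load(open("evadb_data/models/PredictRent", "rb"))

frames = pd.DataFrame(
    {
        "number_of_rooms": [3],
        "number_of_bathrooms": [2],
        "days_on_market": [12],
        "rental_price": [0.0],  # forwarded by the query but never used as a feature
    }
)

# forward() drops the predict column before calling predict() ...
predict_df = pd.DataFrame(model.predict(frames.drop(columns=["rental_price"])))
# ... and renames the single output column back to the predict column name.
predict_df.rename(columns={0: "rental_price"}, inplace=True)
print(predict_df)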