From 2671059e42645f590cbfe2558aa0eaefa300e36c Mon Sep 17 00:00:00 2001 From: Andy Xu Date: Sat, 9 Sep 2023 21:57:56 -0700 Subject: [PATCH] Fix column name related issue for Forecast functions (#1084) Addressing item3 in #1081 * [x] In `evadb/executor/create_function_executor.py`, we rename the input relationship to a [fixed schema](https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html) requested by statsforecast * [x] Rename the output column so it is synced with binder. A temporal fix. We will reconsider the design in #1017 * [x] Update testcases to test the column rename feature. --- data/forecasting/home_sales.csv | 348 ++++++++++++++++++ evadb/executor/create_function_executor.py | 35 +- evadb/functions/forecast.py | 28 +- .../long/test_model_forecasting.py | 43 ++- 4 files changed, 409 insertions(+), 45 deletions(-) create mode 100644 data/forecasting/home_sales.csv diff --git a/data/forecasting/home_sales.csv b/data/forecasting/home_sales.csv new file mode 100644 index 000000000..5cca54415 --- /dev/null +++ b/data/forecasting/home_sales.csv @@ -0,0 +1,348 @@ +saledate,ma,type,bedrooms +30/09/2007,441854,house,2 +31/12/2007,441854,house,2 +31/03/2008,441854,house,2 +30/06/2008,441854,house,2 +30/09/2008,451583,house,2 +31/12/2008,440256,house,2 +31/03/2009,442566,house,2 +30/06/2009,446113,house,2 +30/09/2009,440123,house,2 +31/12/2009,442131,house,2 +31/03/2010,459222,house,2 +30/06/2010,456822,house,2 +30/09/2010,457806,house,2 +31/12/2010,459109,house,2 +31/03/2011,460758,house,2 +30/06/2011,464788,house,2 +30/09/2011,467546,house,2 +31/12/2011,470333,house,2 +31/03/2012,470365,house,2 +30/06/2012,469149,house,2 +30/09/2012,465919,house,2 +31/12/2012,463090,house,2 +31/03/2013,451077,house,2 +30/06/2013,451516,house,2 +30/09/2013,454270,house,2 +31/12/2013,456548,house,2 +31/03/2014,469920,house,2 +30/06/2014,472726,house,2 +30/09/2014,475326,house,2 +31/12/2014,478413,house,2 +31/03/2015,478398,house,2 +30/06/2015,477238,house,2 +30/09/2015,477330,house,2 +31/12/2015,479010,house,2 +31/03/2016,482440,house,2 +30/06/2016,486436,house,2 +30/09/2016,489104,house,2 +31/12/2016,491152,house,2 +31/03/2017,494544,house,2 +30/06/2017,498846,house,2 +30/09/2017,504592,house,2 +31/12/2017,506578,house,2 +31/03/2018,507248,house,2 +30/06/2018,506116,house,2 +30/09/2018,504318,house,2 +31/12/2018,506001,house,2 +31/03/2019,496133,house,2 +30/06/2019,500158,house,2 +30/09/2019,510712,house,2 +31/03/2007,421291,house,3 +30/06/2007,421291,house,3 +30/09/2007,421291,house,3 +31/12/2007,421291,house,3 +31/03/2008,416031,house,3 +30/06/2008,419628,house,3 +30/09/2008,423811,house,3 +31/12/2008,426488,house,3 +31/03/2009,437724,house,3 +30/06/2009,444351,house,3 +30/09/2009,449742,house,3 +31/12/2009,457394,house,3 +31/03/2010,466433,house,3 +30/06/2010,474590,house,3 +30/09/2010,483176,house,3 +31/12/2010,491715,house,3 +31/03/2011,498022,house,3 +30/06/2011,503891,house,3 +30/09/2011,507090,house,3 +31/12/2011,507744,house,3 +31/03/2012,507449,house,3 +30/06/2012,507014,house,3 +30/09/2012,506615,house,3 +31/12/2012,506615,house,3 +31/03/2013,506380,house,3 +30/06/2013,505739,house,3 +30/09/2013,505823,house,3 +31/12/2013,506406,house,3 +31/03/2014,508499,house,3 +30/06/2014,512374,house,3 +30/09/2014,516618,house,3 +31/12/2014,522103,house,3 +31/03/2015,528926,house,3 +30/06/2015,534927,house,3 +30/09/2015,542051,house,3 +31/12/2015,549278,house,3 +31/03/2016,556586,house,3 +30/06/2016,564267,house,3 +30/09/2016,572582,house,3 +31/12/2016,581485,house,3 +31/03/2017,590949,house,3 +30/06/2017,601041,house,3 +30/09/2017,609355,house,3 +31/12/2017,615743,house,3 +31/03/2018,619638,house,3 +30/06/2018,622466,house,3 +30/09/2018,624602,house,3 +31/12/2018,626608,house,3 +31/03/2019,628423,house,3 +30/06/2019,630814,house,3 +30/09/2019,631875,house,3 +31/03/2007,548969,house,4 +30/06/2007,548969,house,4 +30/09/2007,548969,house,4 +31/12/2007,548969,house,4 +31/03/2008,552484,house,4 +30/06/2008,559580,house,4 +30/09/2008,561852,house,4 +31/12/2008,565467,house,4 +31/03/2009,569682,house,4 +30/06/2009,574680,house,4 +30/09/2009,579369,house,4 +31/12/2009,588379,house,4 +31/03/2010,599614,house,4 +30/06/2010,608528,house,4 +30/09/2010,615603,house,4 +31/12/2010,623105,house,4 +31/03/2011,628969,house,4 +30/06/2011,634155,house,4 +30/09/2011,636582,house,4 +31/12/2011,637421,house,4 +31/03/2012,635411,house,4 +30/06/2012,633695,house,4 +30/09/2012,634803,house,4 +31/12/2012,633875,house,4 +31/03/2013,634229,house,4 +30/06/2013,635515,house,4 +30/09/2013,636687,house,4 +31/12/2013,641125,house,4 +31/03/2014,648174,house,4 +30/06/2014,655757,house,4 +30/09/2014,664635,house,4 +31/12/2014,673762,house,4 +31/03/2015,684006,house,4 +30/06/2015,694800,house,4 +30/09/2015,706711,house,4 +31/12/2015,718261,house,4 +31/03/2016,727736,house,4 +30/06/2016,737159,house,4 +30/09/2016,745430,house,4 +31/12/2016,755683,house,4 +31/03/2017,771216,house,4 +30/06/2017,789732,house,4 +30/09/2017,810694,house,4 +31/12/2017,828058,house,4 +31/03/2018,836056,house,4 +30/06/2018,837295,house,4 +30/09/2018,830727,house,4 +31/12/2018,820924,house,4 +31/03/2019,811121,house,4 +30/06/2019,803925,house,4 +30/09/2019,791446,house,4 +30/09/2007,735904,house,5 +31/12/2007,735904,house,5 +31/03/2008,735904,house,5 +30/06/2008,735904,house,5 +30/09/2008,758340,house,5 +31/12/2008,764025,house,5 +31/03/2009,770046,house,5 +30/06/2009,765555,house,5 +30/09/2009,765515,house,5 +31/12/2009,771280,house,5 +31/03/2010,773355,house,5 +30/06/2010,776325,house,5 +30/09/2010,772699,house,5 +31/12/2010,775199,house,5 +31/03/2011,778470,house,5 +30/06/2011,789627,house,5 +30/09/2011,789614,house,5 +31/12/2011,790965,house,5 +31/03/2012,794533,house,5 +30/06/2012,792171,house,5 +30/09/2012,800432,house,5 +31/12/2012,804474,house,5 +31/03/2013,807826,house,5 +30/06/2013,812224,house,5 +30/09/2013,805066,house,5 +31/12/2013,805682,house,5 +31/03/2014,811908,house,5 +30/06/2014,820368,house,5 +30/09/2014,843904,house,5 +31/12/2014,855039,house,5 +31/03/2015,866489,house,5 +30/06/2015,880625,house,5 +30/09/2015,891981,house,5 +31/12/2015,909131,house,5 +31/03/2016,923594,house,5 +30/06/2016,933589,house,5 +30/09/2016,952327,house,5 +31/12/2016,968331,house,5 +31/03/2017,980953,house,5 +30/06/2017,995349,house,5 +30/09/2017,1004117,house,5 +31/12/2017,1010848,house,5 +31/03/2018,1015529,house,5 +30/06/2018,1017752,house,5 +30/09/2018,1007114,house,5 +31/12/2018,1002323,house,5 +31/03/2019,998136,house,5 +30/06/2019,995363,house,5 +30/09/2019,970268,house,5 +31/12/2007,326076,unit,1 +31/03/2008,326076,unit,1 +30/06/2008,326076,unit,1 +30/09/2008,326076,unit,1 +31/12/2008,327321,unit,1 +31/03/2009,324712,unit,1 +30/06/2009,323556,unit,1 +30/09/2009,318922,unit,1 +31/12/2009,316914,unit,1 +31/03/2010,316751,unit,1 +30/06/2010,317711,unit,1 +30/09/2010,318695,unit,1 +31/12/2010,324778,unit,1 +31/03/2011,329856,unit,1 +30/06/2011,333049,unit,1 +30/09/2011,337144,unit,1 +31/12/2011,337400,unit,1 +31/03/2012,339125,unit,1 +30/06/2012,341807,unit,1 +30/09/2012,344793,unit,1 +31/12/2012,347754,unit,1 +31/03/2013,348491,unit,1 +30/06/2013,348512,unit,1 +30/09/2013,347962,unit,1 +31/12/2013,345573,unit,1 +31/03/2014,343298,unit,1 +30/06/2014,341289,unit,1 +30/09/2014,338293,unit,1 +31/12/2014,336520,unit,1 +31/03/2015,334488,unit,1 +30/06/2015,332703,unit,1 +30/09/2015,330278,unit,1 +31/12/2015,328300,unit,1 +31/03/2016,326476,unit,1 +30/06/2016,324725,unit,1 +30/09/2016,325127,unit,1 +31/12/2016,325521,unit,1 +31/03/2017,327870,unit,1 +30/06/2017,330319,unit,1 +30/09/2017,332481,unit,1 +31/12/2017,334804,unit,1 +31/03/2018,336637,unit,1 +30/06/2018,338105,unit,1 +30/09/2018,339220,unit,1 +31/12/2018,339350,unit,1 +31/03/2019,337838,unit,1 +30/06/2019,336551,unit,1 +30/09/2019,335449,unit,1 +30/06/2007,368817,unit,2 +30/09/2007,368817,unit,2 +31/12/2007,368817,unit,2 +31/03/2008,368817,unit,2 +30/06/2008,373482,unit,2 +30/09/2008,377481,unit,2 +31/12/2008,382010,unit,2 +31/03/2009,380810,unit,2 +30/06/2009,385791,unit,2 +30/09/2009,391161,unit,2 +31/12/2009,396448,unit,2 +31/03/2010,402898,unit,2 +30/06/2010,408608,unit,2 +30/09/2010,412509,unit,2 +31/12/2010,415991,unit,2 +31/03/2011,417970,unit,2 +30/06/2011,419777,unit,2 +30/09/2011,421158,unit,2 +31/12/2011,423144,unit,2 +31/03/2012,424673,unit,2 +30/06/2012,424249,unit,2 +30/09/2012,425453,unit,2 +31/12/2012,425922,unit,2 +31/03/2013,425751,unit,2 +30/06/2013,426621,unit,2 +30/09/2013,428398,unit,2 +31/12/2013,428365,unit,2 +31/03/2014,429283,unit,2 +30/06/2014,429361,unit,2 +30/09/2014,428911,unit,2 +31/12/2014,429832,unit,2 +31/03/2015,431567,unit,2 +30/06/2015,432730,unit,2 +30/09/2015,432791,unit,2 +31/12/2015,432801,unit,2 +31/03/2016,431418,unit,2 +30/06/2016,430880,unit,2 +30/09/2016,430654,unit,2 +31/12/2016,430308,unit,2 +31/03/2017,429897,unit,2 +30/06/2017,429059,unit,2 +30/09/2017,428878,unit,2 +31/12/2017,428532,unit,2 +31/03/2018,427856,unit,2 +30/06/2018,427623,unit,2 +30/09/2018,426970,unit,2 +31/12/2018,426936,unit,2 +31/03/2019,426669,unit,2 +30/06/2019,425659,unit,2 +30/09/2019,424412,unit,2 +30/09/2007,518911,unit,3 +31/12/2007,518911,unit,3 +31/03/2008,518911,unit,3 +30/06/2008,518911,unit,3 +30/09/2008,518911,unit,3 +31/12/2008,518911,unit,3 +31/03/2009,518911,unit,3 +30/06/2009,518911,unit,3 +30/09/2009,523285,unit,3 +31/12/2009,522862,unit,3 +31/03/2010,524008,unit,3 +30/06/2010,535063,unit,3 +30/09/2010,538694,unit,3 +31/12/2010,555117,unit,3 +31/03/2011,550851,unit,3 +30/06/2011,547981,unit,3 +30/09/2011,539828,unit,3 +31/12/2011,530987,unit,3 +31/03/2012,540344,unit,3 +30/06/2012,537592,unit,3 +30/09/2012,548326,unit,3 +31/12/2012,555644,unit,3 +31/03/2013,566706,unit,3 +30/06/2013,580696,unit,3 +30/09/2013,581428,unit,3 +31/12/2013,586470,unit,3 +31/03/2014,583883,unit,3 +30/06/2014,583370,unit,3 +30/09/2014,598512,unit,3 +31/12/2014,598812,unit,3 +31/03/2015,599507,unit,3 +30/06/2015,602877,unit,3 +30/09/2015,603343,unit,3 +31/12/2015,612295,unit,3 +31/03/2016,617363,unit,3 +30/06/2016,622045,unit,3 +30/09/2016,616198,unit,3 +31/12/2016,610618,unit,3 +31/03/2017,606935,unit,3 +30/06/2017,605273,unit,3 +30/09/2017,606850,unit,3 +31/12/2017,604413,unit,3 +31/03/2018,604293,unit,3 +30/06/2018,603434,unit,3 +30/09/2018,603281,unit,3 +31/12/2018,601167,unit,3 +31/03/2019,605637,unit,3 +30/06/2019,599339,unit,3 +30/09/2019,597884,unit,3 diff --git a/evadb/executor/create_function_executor.py b/evadb/executor/create_function_executor.py index 7b99458c8..89b1db6f3 100644 --- a/evadb/executor/create_function_executor.py +++ b/evadb/executor/create_function_executor.py @@ -159,12 +159,20 @@ def handle_forecasting_function(self): model_name = arg_map["model"] frequency = arg_map["frequency"] - data = aggregated_batch.frames.rename(columns={arg_map["predict"]: "y"}) + """ + The following rename is needed for statsforecast, which requires the column name to be the following: + - The unique_id (string, int or category) represents an identifier for the series. + - The ds (datestamp) column should be of a format expected by Pandas, ideally YYYY-MM-DD for a date or YYYY-MM-DD HH:MM:SS for a timestamp. + - The y (numeric) represents the measurement we wish to forecast. + For reference: https://nixtla.github.io/statsforecast/docs/getting-started/getting_started_short.html + """ + aggregated_batch.rename(columns={arg_map["predict"]: "y"}) if "time" in arg_map.keys(): - aggregated_batch.frames.rename(columns={arg_map["time"]: "ds"}) + aggregated_batch.rename(columns={arg_map["time"]: "ds"}) if "id" in arg_map.keys(): - aggregated_batch.frames.rename(columns={arg_map["id"]: "unique_id"}) + aggregated_batch.rename(columns={arg_map["id"]: "unique_id"}) + data = aggregated_batch.frames if "unique_id" not in list(data.columns): data["unique_id"] = ["test" for x in range(len(data))] @@ -219,25 +227,12 @@ def handle_forecasting_function(self): pickle.dump(model, f) f.close() - arg_map_here = {"model_name": model_name, "model_path": model_path} - function = self._try_initializing_function(impl_path, arg_map_here) - io_list = self._resolve_function_io(function) + io_list = self._resolve_function_io(None) metadata_here = [ - FunctionMetadataCatalogEntry( - key="model_name", - value=model_name, - function_id=None, - function_name=None, - row_id=None, - ), - FunctionMetadataCatalogEntry( - key="model_path", - value=model_path, - function_id=None, - function_name=None, - row_id=None, - ), + FunctionMetadataCatalogEntry("model_name", model_name), + FunctionMetadataCatalogEntry("model_path", model_path), + FunctionMetadataCatalogEntry("output_column_rename", arg_map["predict"]), ] return ( diff --git a/evadb/functions/forecast.py b/evadb/functions/forecast.py index eb4e8c0ab..01e7b3f6a 100644 --- a/evadb/functions/forecast.py +++ b/evadb/functions/forecast.py @@ -18,10 +18,8 @@ import pandas as pd -from evadb.catalog.catalog_type import NdArrayType from evadb.functions.abstract.abstract_function import AbstractFunction -from evadb.functions.decorators.decorators import forward, setup -from evadb.functions.decorators.io_descriptors.data_types import PandasDataframe +from evadb.functions.decorators.decorators import setup class ForecastModel(AbstractFunction): @@ -30,35 +28,21 @@ def name(self) -> str: return "ForecastModel" @setup(cacheable=False, function_type="Forecasting", batchable=True) - def setup(self, model_name: str, model_path: str): + def setup(self, model_name: str, model_path: str, output_column_rename: str): f = open(model_path, "rb") loaded_model = pickle.load(f) f.close() self.model = loaded_model self.model_name = model_name + self.output_column_rename = output_column_rename - @forward( - input_signatures=[], - output_signatures=[ - PandasDataframe( - columns=["y"], - column_types=[ - NdArrayType.FLOAT32, - ], - column_shapes=[(None,)], - ) - ], - ) def forward(self, data) -> pd.DataFrame: horizon = list(data.iloc[:, -1])[0] assert ( type(horizon) is int ), "Forecast UDF expects integral horizon in parameter." forecast_df = self.model.predict(h=horizon) - forecast_df = forecast_df.rename(columns={self.model_name: "y"}) - return pd.DataFrame( - forecast_df, - columns=[ - "y", - ], + forecast_df = forecast_df.rename( + columns={self.model_name: self.output_column_rename} ) + return forecast_df diff --git a/test/integration_tests/long/test_model_forecasting.py b/test/integration_tests/long/test_model_forecasting.py index 288eb8dd7..f48a73f7e 100644 --- a/test/integration_tests/long/test_model_forecasting.py +++ b/test/integration_tests/long/test_model_forecasting.py @@ -37,21 +37,37 @@ def setUpClass(cls): y INTEGER);""" execute_query_fetch_all(cls.evadb, create_table_query) + create_table_query = """ + CREATE TABLE HomeData (\ + saledate TEXT(30),\ + ma INTEGER, + type TEXT(30),\ + bedrooms INTEGER);""" + execute_query_fetch_all(cls.evadb, create_table_query) + path = f"{EvaDB_ROOT_DIR}/data/forecasting/air-passengers.csv" load_query = f"LOAD CSV '{path}' INTO AirData;" execute_query_fetch_all(cls.evadb, load_query) + path = f"{EvaDB_ROOT_DIR}/data/forecasting/home_sales.csv" + load_query = f"LOAD CSV '{path}' INTO HomeData;" + execute_query_fetch_all(cls.evadb, load_query) + @classmethod def tearDownClass(cls): shutdown_ray() # clean up - execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS HomeRentals;") + execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS AirData;") + execute_query_fetch_all(cls.evadb, "DROP TABLE IF EXISTS HomeData;") + + execute_query_fetch_all(cls.evadb, "DROP FUNCTION IF EXISTS AirForecast;") + execute_query_fetch_all(cls.evadb, "DROP FUNCTION IF EXISTS HomeForecast;") @forecast_skip_marker def test_forecast(self): create_predict_udf = """ - CREATE FUNCTION Forecast FROM + CREATE FUNCTION AirForecast FROM (SELECT unique_id, ds, y FROM AirData) TYPE Forecasting PREDICT 'y'; @@ -59,11 +75,32 @@ def test_forecast(self): execute_query_fetch_all(self.evadb, create_predict_udf) predict_query = """ - SELECT Forecast(12) FROM AirData; + SELECT AirForecast(12) FROM AirData; """ result = execute_query_fetch_all(self.evadb, predict_query) self.assertEqual(int(list(result.frames.iloc[:, -1])[-1]), 459) + @forecast_skip_marker + def test_forecast_with_column_rename(self): + create_predict_udf = """ + CREATE FUNCTION HomeForecast FROM + ( + SELECT saledate, ma FROM HomeData + WHERE type = "house" AND bedrooms = 2 + ) + TYPE Forecasting + PREDICT 'ma' + TIME 'saledate'; + """ + execute_query_fetch_all(self.evadb, create_predict_udf) + + predict_query = """ + SELECT HomeForecast(12) FROM AirData; + """ + result = execute_query_fetch_all(self.evadb, predict_query) + self.assertEqual(len(result), 12) + self.assertEqual(result.columns, ["homeforecast.ma"]) + if __name__ == "__main__": unittest.main()