From c36819d70e2c22de3a576d1b7e9b605a4ad7ba8c Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Wed, 10 Apr 2024 21:31:10 -0400 Subject: [PATCH 1/6] ensure ids are a tuple Signed-off-by: Wesley M. Gifford --- tsfm_public/toolkit/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py index 27f782b3..57cb9c6c 100644 --- a/tsfm_public/toolkit/dataset.py +++ b/tsfm_public/toolkit/dataset.py @@ -232,7 +232,7 @@ def get_group_data( ): return cls( data_df=group, - group_id=group_id, + group_id=group_id if isinstance(group_id, tuple) else (group_id,), id_columns=id_columns, timestamp_column=timestamp_column, context_length=context_length, From 55a53ff996cf4acaacc0df12af3ba5f7b6d5eb10 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Wed, 10 Apr 2024 21:31:31 -0400 Subject: [PATCH 2/6] consistently use ruff Signed-off-by: Wesley M. Gifford --- .vscode/settings.json | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 913daa06..bd38a513 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,18 +1,17 @@ { "[python]": { - "editor.defaultFormatter": "ms-python.black-formatter", + "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, - "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" - } }, - "isort.args": [ - "--profile", - "black" - ], "python.testing.pytestArgs": [ "tests" ], "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true + "python.testing.pytestEnabled": true, + "ruff.format.args": [ + "--config=./pyproject.toml" + ], + "ruff.lint.args": [ + "--config=./pyproject.toml" + ] } From 94d96ad9d4350cfdc9d06873765ffe9cf7c160c8 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Tue, 23 Apr 2024 11:42:25 -0400 Subject: [PATCH 3/6] Add docstring Signed-off-by: Wesley M. Gifford --- tsfm_public/toolkit/time_series_forecasting_pipeline.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tsfm_public/toolkit/time_series_forecasting_pipeline.py b/tsfm_public/toolkit/time_series_forecasting_pipeline.py index d8a3b328..15f55db6 100644 --- a/tsfm_public/toolkit/time_series_forecasting_pipeline.py +++ b/tsfm_public/toolkit/time_series_forecasting_pipeline.py @@ -300,6 +300,10 @@ def postprocess(self, input, **kwargs): """Postprocess step Takes the dictionary of outputs from the previous step and converts to a more user readable pandas format. + + If the explode forecasts option is True, then individual forecasts are expanded as multiple + rows in the dataframe. This should only be used when producing a single forecast (i.e., unexploded + result is one row per ID). """ out = {} From 30ce4367ee98a2785f97be10403fd526267e07f8 Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Tue, 23 Apr 2024 12:05:20 -0400 Subject: [PATCH 4/6] try to handle freq directly Signed-off-by: Wesley M. Gifford --- tsfm_public/toolkit/time_series_preprocessor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index b88dcad7..a69db0a6 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -759,8 +759,8 @@ def create_timestamps( # more complex logic is required to support all edge cases if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)): - if isinstance(freq, str): - freq = pd._libs.tslibs.timedeltas.Timedelta(freq) + # if isinstance(freq, str): + # freq = pd._libs.tslibs.timedeltas.Timedelta(freq) return pd.date_range( last_timestamp, From 493d42c2a4d2675cf018c9a492ddc716c71d408c Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Tue, 23 Apr 2024 12:28:36 -0400 Subject: [PATCH 5/6] add test for unusual time frequency Signed-off-by: Wesley M. Gifford --- .../toolkit/test_time_series_preprocessor.py | 17 ++++++++++-- .../toolkit/time_series_preprocessor.py | 26 +++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/tests/toolkit/test_time_series_preprocessor.py b/tests/toolkit/test_time_series_preprocessor.py index 9c052fec..5a813c77 100644 --- a/tests/toolkit/test_time_series_preprocessor.py +++ b/tests/toolkit/test_time_series_preprocessor.py @@ -212,6 +212,18 @@ def test_create_timestamps(): 2, [103.5, 107.0], ), + ( + pd.Timestamp(2021, 12, 31), + "QE", + None, + 4, + [ + pd.Timestamp(2022, 3, 31), + pd.Timestamp(2022, 6, 30), + pd.Timestamp(2022, 9, 30), + pd.Timestamp(2022, 12, 31), + ], + ), ] for start, freq, sequence, periods, expected in test_cases: @@ -220,8 +232,9 @@ def test_create_timestamps(): assert ts == expected # test based on provided sequence - ts = create_timestamps(start, time_sequence=sequence, periods=periods) - assert ts == expected + if sequence is not None: + ts = create_timestamps(start, time_sequence=sequence, periods=periods) + assert ts == expected # it is an error to provide neither freq or sequence with pytest.raises(ValueError): diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index a69db0a6..68bfa5db 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -759,14 +759,24 @@ def create_timestamps( # more complex logic is required to support all edge cases if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)): - # if isinstance(freq, str): - # freq = pd._libs.tslibs.timedeltas.Timedelta(freq) - - return pd.date_range( - last_timestamp, - freq=freq, - periods=periods + 1, - ).tolist()[1:] + try: + # try date range directly + return pd.date_range( + last_timestamp, + freq=freq, + periods=periods + 1, + ).tolist()[1:] + except ValueError as e: + # if it fails, we can try to compute a timedelta from the provided string + if isinstance(freq, str): + freq = pd._libs.tslibs.timedeltas.Timedelta(freq) + return pd.date_range( + last_timestamp, + freq=freq, + periods=periods + 1, + ).tolist()[1:] + else: + raise e else: # numerical timestamp column return [last_timestamp + i * freq for i in range(1, periods + 1)] From fc09e1ad571709d8becea1ee5bc2cc77f94a9b3f Mon Sep 17 00:00:00 2001 From: "Wesley M. Gifford" Date: Tue, 23 Apr 2024 13:51:02 -0400 Subject: [PATCH 6/6] clean up output Signed-off-by: Wesley M. Gifford --- .../toolkit/time_series_preprocessor.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/tsfm_public/toolkit/time_series_preprocessor.py b/tsfm_public/toolkit/time_series_preprocessor.py index 68bfa5db..4de30c24 100644 --- a/tsfm_public/toolkit/time_series_preprocessor.py +++ b/tsfm_public/toolkit/time_series_preprocessor.py @@ -344,11 +344,29 @@ def _standardize_dataframe( return df + def _clean_up_dataframe(self, df: pd.DataFrame) -> None: + """Removes columns added during internal processing of the provided dataframe. + + Currently, the following checks are done: + - Remove INTERNAL_ID_COLUMN if present + + Args: + df (pd.DataFrame): Input pandas dataframe + + Returns: + pd.DataFrame: Cleaned up dataframe + """ + + if not self.id_columns: + if INTERNAL_ID_COLUMN in df.columns: + df.drop(columns=INTERNAL_ID_COLUMN, inplace=True) + def _get_groups( self, dataset: pd.DataFrame, ) -> Generator[Tuple[Any, pd.DataFrame], None, None]: - """Get groups of the time series dataset (multi-time series) based on the ID columns. + """Get groups of the time series dataset (multi-time series) based on the ID columns for scaling. + Note that this is used for scaling purposes only. Args: dataset (pd.DataFrame): Input dataset @@ -472,7 +490,7 @@ def _check_dataset(self, dataset: Union[Dataset, pd.DataFrame]): def _set_targets(self, dataset: pd.DataFrame) -> None: if self.target_columns == []: - skip_columns = copy.copy(self.id_columns) + skip_columns = copy.copy(self.id_columns) + [INTERNAL_ID_COLUMN] if self.timestamp_column: skip_columns.append(self.timestamp_column) @@ -531,6 +549,7 @@ def train( if self.encode_categorical: self._train_categorical_encoder(df) + self._clean_up_dataframe(df) return self def inverse_scale_targets( @@ -581,10 +600,12 @@ def inverse_scale_func(grp, id_columns): else: id_columns = INTERNAL_ID_COLUMN - return df.groupby(id_columns, group_keys=False).apply( + df_inv = df.groupby(id_columns, group_keys=False).apply( inverse_scale_func, id_columns=id_columns, ) + self._clean_up_dataframe(df_inv) + return df_inv def preprocess( self, @@ -640,6 +661,7 @@ def scale_func(grp, id_columns): raise RuntimeError("Attempt to encode categorical columns, but the encoder has not been trained yet.") df[cols_to_encode] = self.categorical_encoder.transform(df[cols_to_encode]) + self._clean_up_dataframe(df) return df def get_datasets(