Issue 38 #39

Merged: 7 commits, Apr 24, 2024
Changes from all commits
17 changes: 8 additions & 9 deletions .vscode/settings.json
@@ -1,18 +1,17 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
}
},
"isort.args": [
"--profile",
"black"
],
"python.testing.pytestArgs": [
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"ruff.format.args": [
"--config=./pyproject.toml"
],
"ruff.lint.args": [
"--config=./pyproject.toml"
]
}
17 changes: 15 additions & 2 deletions tests/toolkit/test_time_series_preprocessor.py
@@ -212,6 +212,18 @@ def test_create_timestamps():
2,
[103.5, 107.0],
),
(
pd.Timestamp(2021, 12, 31),
"QE",
None,
4,
[
pd.Timestamp(2022, 3, 31),
pd.Timestamp(2022, 6, 30),
pd.Timestamp(2022, 9, 30),
pd.Timestamp(2022, 12, 31),
],
),
]

for start, freq, sequence, periods, expected in test_cases:
@@ -220,8 +232,9 @@ def test_create_timestamps():
assert ts == expected

# test based on provided sequence
ts = create_timestamps(start, time_sequence=sequence, periods=periods)
assert ts == expected
if sequence is not None:
ts = create_timestamps(start, time_sequence=sequence, periods=periods)
assert ts == expected

# it is an error to provide neither freq nor sequence
with pytest.raises(ValueError):
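For reference, the new parametrized case exercises quarter-end frequencies. Below is a minimal standalone sketch of what it checks, using pandas directly rather than the toolkit helper (it assumes a pandas version that accepts the "QE" alias, i.e. pandas >= 2.2; older versions use "Q"):

```python
import pandas as pd

# Starting from the last observed quarter end, generate the next four
# quarter-end timestamps. The range includes the start itself, so it is
# dropped, mirroring the date_range(...).tolist()[1:] pattern used in
# create_timestamps.
last_timestamp = pd.Timestamp(2021, 12, 31)
future = pd.date_range(last_timestamp, freq="QE", periods=5).tolist()[1:]

assert future == [
    pd.Timestamp(2022, 3, 31),
    pd.Timestamp(2022, 6, 30),
    pd.Timestamp(2022, 9, 30),
    pd.Timestamp(2022, 12, 31),
]
```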
2 changes: 1 addition & 1 deletion tsfm_public/toolkit/dataset.py
@@ -232,7 +232,7 @@ def get_group_data(
):
return cls(
data_df=group,
group_id=group_id,
group_id=group_id if isinstance(group_id, tuple) else (group_id,),
id_columns=id_columns,
timestamp_column=timestamp_column,
context_length=context_length,
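For context, this normalization compensates for how pandas reports groupby keys: grouping by a single column yields scalar keys, while grouping by several columns yields tuples. A minimal standalone sketch of the idea (illustrative only, not the toolkit's code):

```python
import pandas as pd

df = pd.DataFrame({"id": ["A", "A", "B"], "value": [1, 2, 3]})

for group_id, group in df.groupby("id"):
    # With a single grouping column, pandas yields scalar keys ("A", "B").
    # Wrapping them in a 1-tuple lets downstream code treat single- and
    # multi-column IDs uniformly, which is what the dataset.py change does.
    group_id = group_id if isinstance(group_id, tuple) else (group_id,)
    print(group_id, len(group))  # ('A',) 2  /  ('B',) 1
```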
4 changes: 4 additions & 0 deletions tsfm_public/toolkit/time_series_forecasting_pipeline.py
@@ -300,6 +300,10 @@ def postprocess(self, input, **kwargs):
"""Postprocess step
Takes the dictionary of outputs from the previous step and converts to a more user
readable pandas format.

If the explode forecasts option is True, then individual forecasts are expanded into multiple
rows in the dataframe. This should only be used when producing a single forecast (i.e., the
unexploded result is one row per ID).
"""
out = {}

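To illustrate the documented behavior, here is a hedged sketch of what "exploding" a single forecast looks like in pandas; the column names are illustrative and not necessarily the pipeline's actual output schema:

```python
import pandas as pd

# One row per ID, with a list-valued forecast column (unexploded form).
unexploded = pd.DataFrame({"id": ["A"], "value_prediction": [[10.0, 11.0, 12.0]]})

# Exploding expands the single forecast into one row per forecast step.
exploded = unexploded.explode("value_prediction", ignore_index=True)
print(exploded)
#   id value_prediction
# 0  A             10.0
# 1  A             11.0
# 2  A             12.0
```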
54 changes: 43 additions & 11 deletions tsfm_public/toolkit/time_series_preprocessor.py
@@ -344,11 +344,29 @@ def _standardize_dataframe(

return df

def _clean_up_dataframe(self, df: pd.DataFrame) -> None:
"""Removes columns added during internal processing of the provided dataframe.

Currently, the following clean-up steps are performed:
- Remove INTERNAL_ID_COLUMN if present

The dataframe is modified in place; nothing is returned.

Args:
df (pd.DataFrame): Input pandas dataframe
"""

if not self.id_columns:
if INTERNAL_ID_COLUMN in df.columns:
df.drop(columns=INTERNAL_ID_COLUMN, inplace=True)

def _get_groups(
self,
dataset: pd.DataFrame,
) -> Generator[Tuple[Any, pd.DataFrame], None, None]:
"""Get groups of the time series dataset (multi-time series) based on the ID columns.
"""Get groups of the time series dataset (multi-time series) based on the ID columns for scaling.
Note that this is used for scaling purposes only.

Args:
dataset (pd.DataFrame): Input dataset
@@ -472,7 +490,7 @@ def _check_dataset(self, dataset: Union[Dataset, pd.DataFrame]):

def _set_targets(self, dataset: pd.DataFrame) -> None:
if self.target_columns == []:
skip_columns = copy.copy(self.id_columns)
skip_columns = copy.copy(self.id_columns) + [INTERNAL_ID_COLUMN]
if self.timestamp_column:
skip_columns.append(self.timestamp_column)

@@ -531,6 +549,7 @@ def train(
if self.encode_categorical:
self._train_categorical_encoder(df)

self._clean_up_dataframe(df)
return self

def inverse_scale_targets(
@@ -581,10 +600,12 @@ def inverse_scale_func(grp, id_columns):
else:
id_columns = INTERNAL_ID_COLUMN

return df.groupby(id_columns, group_keys=False).apply(
df_inv = df.groupby(id_columns, group_keys=False).apply(
inverse_scale_func,
id_columns=id_columns,
)
self._clean_up_dataframe(df_inv)
return df_inv

def preprocess(
self,
@@ -640,6 +661,7 @@ def scale_func(grp, id_columns):
raise RuntimeError("Attempt to encode categorical columns, but the encoder has not been trained yet.")
df[cols_to_encode] = self.categorical_encoder.transform(df[cols_to_encode])

self._clean_up_dataframe(df)
return df

def get_datasets(
@@ -759,14 +781,24 @@ def create_timestamps(

# more complex logic is required to support all edge cases
if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)):
if isinstance(freq, str):
freq = pd._libs.tslibs.timedeltas.Timedelta(freq)

return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
try:
# try date range directly
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
except ValueError as e:
# if it fails, we can try to compute a timedelta from the provided string
if isinstance(freq, str):
freq = pd._libs.tslibs.timedeltas.Timedelta(freq)
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
else:
raise e
else:
# numerical timestamp column
return [last_timestamp + i * freq for i in range(1, periods + 1)]
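The revised logic tries pd.date_range with the raw frequency first and only falls back to a Timedelta conversion when that raises. A standalone sketch of the two paths follows (the fallback string below is an assumed example of input that pandas rejects as an offset alias but accepts as a Timedelta; the "QE" alias requires pandas >= 2.2):

```python
import pandas as pd

last = pd.Timestamp("2021-12-31")

# Path 1: an offset alias such as "QE" is handled directly by pd.date_range.
# Converting it first with pd.Timedelta("QE") would raise, which is why the
# conversion is now only attempted as a fallback.
print(pd.date_range(last, freq="QE", periods=3).tolist()[1:])

# Path 2: a timedelta-style string is not a valid offset alias, so
# pd.date_range raises ValueError; converting it to a Timedelta first works.
freq = "0 days 01:00:00"
try:
    ts = pd.date_range(last, freq=freq, periods=3).tolist()[1:]
except ValueError:
    ts = pd.date_range(last, freq=pd.Timedelta(freq), periods=3).tolist()[1:]
print(ts)  # hourly steps: 2021-12-31 01:00:00, 2021-12-31 02:00:00
```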