Skip to content

Commit

Permalink
add test for freq token, update mapping, update freq token logic
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Jun 5, 2024
1 parent 22b3cb0 commit 34869cb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
10 changes: 10 additions & 0 deletions tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest

from tsfm_public.toolkit.time_series_preprocessor import (
DEFAULT_FREQUENCY_MAPPING,
OrdinalEncoder,
StandardScaler,
TimeSeriesPreprocessor,
Expand Down Expand Up @@ -386,6 +387,15 @@ def test_get_datasets_without_targets(ts_data):
train.datasets[0].target_columns == ["value1", "value2"]


def test_get_datasets_with_frequency_token(ts_data):
    """Check that `get_datasets` emits the mapped frequency token when asked to.

    The preprocessor is configured with `freq="d"`, so the first training
    example is expected to carry the token stored under `"d"` in
    `DEFAULT_FREQUENCY_MAPPING`.
    """
    # Drop the id columns before building the datasets.
    frame = ts_data.drop(columns=["id", "id2"])

    preprocessor = TimeSeriesPreprocessor(
        timestamp_column="timestamp",
        prediction_length=2,
        context_length=5,
        freq="d",
    )

    # Only the training split is inspected; the other two are discarded.
    train_dataset, _valid_dataset, _test_dataset = get_datasets(
        preprocessor,
        frame,
        split_config={"train": 0.7, "test": 0.2},
        use_frequency_token=True,
    )

    first_example = train_dataset[0]
    expected_token = DEFAULT_FREQUENCY_MAPPING["d"]
    assert first_example["freq_token"] == expected_token


def test_id_columns_and_scaling_id_columns(ts_data_runs):
df = ts_data_runs

Expand Down
20 changes: 13 additions & 7 deletions tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@
# Name and value of an internal id column.
# NOTE(review): presumably used when the caller supplies no id columns — confirm against the rest of the module.
INTERNAL_ID_COLUMN = "__id"
INTERNAL_ID_VALUE = "0"


# Maps a frequency string to the integer frequency token passed to the datasets.
# "oov" is presumably the out-of-vocabulary fallback — verify against get_frequency_token.
# NOTE(review): the keys "half_hourly", "hourly", "10_minutes" and "15_minutes" reuse the
# token values 1-4 that are also assigned to "min", "2min", "5min" and "10min" below.
# This looks like the old and new versions of the mapping shown together by the diff view;
# confirm which set of keys is current before relying on this literal.
DEFAULT_FREQUENCY_MAPPING = {
"oov": 0,
"half_hourly": 1,
"hourly": 2,
"10_minutes": 3,
"15_minutes": 4,
"min": 1, # minutely
"2min": 2,
"5min": 3,
"10min": 4,
"15min": 5,
"30min": 6,
"h": 7, # hourly
"d": 8, # daily
"W": 9, # weekly
}


Expand Down Expand Up @@ -153,7 +157,7 @@ def __init__(
encode_categorical (bool, optional): If True any categorical columns will be encoded using ordinal encoding. Defaults to True.
time_series_task (str, optional): Reserved for future use. Defaults to TimeSeriesTask.FORECASTING.value.
frequency_mapping (Dict[str, int], optional): Mapping from frequency strings to the integer frequency tokens. Defaults to DEFAULT_FREQUENCY_MAPPING.
freq (Optional[Union[int, str]], optional): A freqency indicator for the given `timestamp_column`. See
freq (Optional[Union[int, str]], optional): A frequency indicator for the given `timestamp_column`. See
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#period-aliases for a description of the
allowed values. If not provided, we will attempt to infer the frequency from the data in
`timestamp_column`. Defaults to None.
Expand Down Expand Up @@ -710,6 +714,7 @@ def get_datasets(
split_config: Dict[str, Union[List[Union[int, float]], float]],
fewshot_fraction: Optional[float] = None,
fewshot_location: str = FractionLocation.LAST.value,
use_frequency_token: bool = False,
) -> Tuple[Any]:
"""Creates the preprocessed pytorch datasets needed for training and evaluation
using the HuggingFace trainer
Expand Down Expand Up @@ -748,6 +753,7 @@ def get_datasets(
split_config=split_config,
fewshot_fraction=fewshot_fraction,
fewshot_location=fewshot_location,
use_frequency_token=use_frequency_token,
)


Expand Down Expand Up @@ -881,7 +887,7 @@ def get_datasets(
params["context_length"] = ts_preprocessor.context_length
params["prediction_length"] = ts_preprocessor.prediction_length
if use_frequency_token:
params["frequency_token"] = ts_preprocessor.get_frequency_token()
params["frequency_token"] = ts_preprocessor.get_frequency_token(ts_preprocessor.freq)

# get torch datasets
train_valid_test = [train_data, valid_data, test_data]
Expand Down

0 comments on commit 34869cb

Please sign in to comment.