Skip to content

Commit

Permalink
add test for unusual time frequency
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Apr 23, 2024
1 parent 30ce436 commit 1f1f89a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 10 deletions.
17 changes: 15 additions & 2 deletions tests/toolkit/test_time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,18 @@ def test_create_timestamps():
2,
[103.5, 107.0],
),
(
pd.Timestamp(2021, 12, 31),
"QE",
None,
4,
[
pd.Timestamp(2022, 3, 31),
pd.Timestamp(2022, 6, 30),
pd.Timestamp(2022, 9, 30),
pd.Timestamp(2022, 12, 31),
],
),
]

for start, freq, sequence, periods, expected in test_cases:
Expand All @@ -220,8 +232,9 @@ def test_create_timestamps():
assert ts == expected

# test based on provided sequence
ts = create_timestamps(start, time_sequence=sequence, periods=periods)
assert ts == expected
if sequence is not None:
ts = create_timestamps(start, time_sequence=sequence, periods=periods)
assert ts == expected

# it is an error to provide neither freq or sequence
with pytest.raises(ValueError):
Expand Down
26 changes: 18 additions & 8 deletions tsfm_public/toolkit/time_series_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,14 +759,24 @@ def create_timestamps(

# more complex logic is required to support all edge cases
if isinstance(freq, (pd.Timedelta, datetime.timedelta, str)):
# if isinstance(freq, str):
# freq = pd._libs.tslibs.timedeltas.Timedelta(freq)

return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
try:
# try date range directly
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
except ValueError as e:
# if it fails, we can try to compute a timedelta from the provided string
if isinstance(freq, str):
freq = pd._libs.tslibs.timedeltas.Timedelta(freq)
return pd.date_range(
last_timestamp,
freq=freq,
periods=periods + 1,
).tolist()[1:]
else:
raise e
else:
# numerical timestamp column
return [last_timestamp + i * freq for i in range(1, periods + 1)]
Expand Down

0 comments on commit 1f1f89a

Please sign in to comment.