Skip to content

Commit

Permalink
feat: infere freq
Browse files Browse the repository at this point in the history
  • Loading branch information
AzulGarza committed Aug 6, 2023
1 parent b3126d6 commit fd36848
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 21 deletions.
56 changes: 37 additions & 19 deletions nbs/timegpt.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions nixtlats/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
'lib_path': 'nixtlats'},
'syms': { 'nixtlats.timegpt': { 'nixtlats.timegpt.TimeGPT': ('timegpt.html#timegpt', 'nixtlats/timegpt.py'),
'nixtlats.timegpt.TimeGPT.__init__': ('timegpt.html#timegpt.__init__', 'nixtlats/timegpt.py'),
'nixtlats.timegpt.TimeGPT._infer_freq': ('timegpt.html#timegpt._infer_freq', 'nixtlats/timegpt.py'),
'nixtlats.timegpt.TimeGPT._input_size': ('timegpt.html#timegpt._input_size', 'nixtlats/timegpt.py'),
'nixtlats.timegpt.TimeGPT._multi_series': ('timegpt.html#timegpt._multi_series', 'nixtlats/timegpt.py'),
'nixtlats.timegpt.TimeGPT._parse_response': ( 'timegpt.html#timegpt._parse_response',
Expand Down
17 changes: 15 additions & 2 deletions nixtlats/timegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ def _validate_outputs(
fcst_df = fcst_df.rename(columns=renamer)
return fcst_df

def _infer_freq(self, df: pd.DataFrame):
unique_id = df.iloc[0]["unique_id"]
df_id = df.query("unique_id == @unique_id")
freq = pd.infer_freq(df_id["ds"])
if freq is None:
raise Exception(
'"Could not infer frequency of ds column. This could be due to \
inconsistent intervals. Please check your data for missing, duplicated or irregular timestamps."'
)
return freq

def _preprocess_inputs(
self,
df: pd.DataFrame,
Expand Down Expand Up @@ -136,6 +147,8 @@ def _multi_series(
finetune_steps: int = 0,
clean_ex_first: bool = True,
):
if freq is None:
freq = self._infer_freq(df)
y, x, x_cols = self._preprocess_inputs(df=df, h=h, freq=freq, X_df=X_df)
payload = dict(
y=y,
Expand Down Expand Up @@ -165,7 +178,7 @@ def forecast(
self,
df: pd.DataFrame,
h: int,
freq: str,
freq: Optional[str] = None,
id_col: str = "unique_id",
time_col: str = "ds",
target_col: str = "y",
Expand Down Expand Up @@ -193,7 +206,7 @@ def forecast(
h : int
Forecast horizon.
freq : str
Frequency of the data.
Frequency of the data. By default, the freq will be inferred automatically.
See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
id_col : str (default='unique_id')
Column that identifies each serie.
Expand Down

0 comments on commit fd36848

Please sign in to comment.