From a450586ec63c6f86e991a439bff780968200e365 Mon Sep 17 00:00:00 2001 From: MMenchero Date: Mon, 14 Oct 2024 14:44:42 -0600 Subject: [PATCH] hotfix: remove error for short series --- R/nixtla_client_forecast.R | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/R/nixtla_client_forecast.R b/R/nixtla_client_forecast.R index 80c4335..03b4099 100644 --- a/R/nixtla_client_forecast.R +++ b/R/nixtla_client_forecast.R @@ -60,20 +60,11 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_ # Infer frequency if necessary ---- freq <- infer_frequency(df, freq) - # Generate fitted values if required ---- - if(add_history){ - fitted <- nixtla_client_historic(df=df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, level=level, quantiles=quantiles, finetune_steps=finetune_steps, finetune_loss=finetune_loss, clean_ex_first=clean_ex_first) - } - # Obtain model parameters ---- model_params <- .get_model_params(model, freq) - # Make sure there is enough data ---- - if(h > model_params$horizon){ - message("The specified horizon h exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.") - } - - if(finetune_steps > 0 | !is.null(level)){ + # Validate input size ---- + if(finetune_steps > 0 | !is.null(level) | add_history){ num_rows <- df |> dplyr::group_by(.data$unique_id) |> dplyr::summarise(initial_size = dplyr::n()) @@ -83,25 +74,21 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_ } } + # Make sure there is enough data ---- + if(h > model_params$horizon){ + message("The specified horizon h exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.") + } + # Restrict input if necessary ---- contains_exogenous <- any(!(names(df) %in% c("unique_id", "ds", "y"))) if(!contains_exogenous & finetune_steps == 0 & !add_history){ - # Input is restricted only when there are no exogenous variables and no finetuning if(is.null(level) & is.null(quantiles)){ input_samples = model_params$input_size }else{ input_samples = 3*model_params$input_size+max(model_params$horizon, h) } - num_rows <- df |> - dplyr::group_by(.data$unique_id) |> - dplyr::summarise(initial_size = dplyr::n()) - - if (any(input_samples > num_rows$initial_size)){ - stop(paste0("Your time series is too short. Please make sure that each of your series contains at least ", model_params$input_size+model_params$horizon, " observations.")) - } - df <- df |> dplyr::group_by(.data$unique_id) |> dplyr::slice_tail(n = input_samples) |> @@ -267,6 +254,7 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_ # Add fitted values if required ---- if(add_history){ + fitted <- nixtla_client_historic(df=df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, level=level, quantiles=quantiles, finetune_steps=finetune_steps, finetune_loss=finetune_loss, clean_ex_first=clean_ex_first) forecast <- dplyr::bind_rows(fitted, forecast) }