Skip to content

Commit

Permalink
Merge pull request #31 from Nixtla/v2-issues
Browse files Browse the repository at this point in the history
hotfix: error raised for short series
  • Loading branch information
MMenchero authored Oct 14, 2024
2 parents 63c9fc2 + a450586 commit a9ac12b
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions R/nixtla_client_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,11 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_
# Infer frequency if necessary ----
freq <- infer_frequency(df, freq)

# Generate fitted values if required ----
if(add_history){
fitted <- nixtla_client_historic(df=df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, level=level, quantiles=quantiles, finetune_steps=finetune_steps, finetune_loss=finetune_loss, clean_ex_first=clean_ex_first)
}

# Obtain model parameters ----
model_params <- .get_model_params(model, freq)

# Make sure there is enough data ----
if(h > model_params$horizon){
message("The specified horizon h exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.")
}

if(finetune_steps > 0 | !is.null(level)){
# Validate input size ----
if(finetune_steps > 0 | !is.null(level) | add_history){
num_rows <- df |>
dplyr::group_by(.data$unique_id) |>
dplyr::summarise(initial_size = dplyr::n())
Expand All @@ -83,25 +74,21 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_
}
}

# Make sure there is enough data ----
if(h > model_params$horizon){
message("The specified horizon h exceeds the model horizon. This may lead to less accurate forecasts. Please consider using a smaller horizon.")
}

# Restrict input if necessary ----
contains_exogenous <- any(!(names(df) %in% c("unique_id", "ds", "y")))

if(!contains_exogenous & finetune_steps == 0 & !add_history){
# Input is restricted only when there are no exogenous variables and no finetuning
if(is.null(level) & is.null(quantiles)){
input_samples = model_params$input_size
}else{
input_samples = 3*model_params$input_size+max(model_params$horizon, h)
}

num_rows <- df |>
dplyr::group_by(.data$unique_id) |>
dplyr::summarise(initial_size = dplyr::n())

if (any(input_samples > num_rows$initial_size)){
stop(paste0("Your time series is too short. Please make sure that each of your series contains at least ", model_params$input_size+model_params$horizon, " observations."))
}

df <- df |>
dplyr::group_by(.data$unique_id) |>
dplyr::slice_tail(n = input_samples) |>
Expand Down Expand Up @@ -267,6 +254,7 @@ nixtla_client_forecast <- function(df, h=8, freq=NULL, id_col="unique_id", time_

# Add fitted values if required ----
if(add_history){
fitted <- nixtla_client_historic(df=df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, level=level, quantiles=quantiles, finetune_steps=finetune_steps, finetune_loss=finetune_loss, clean_ex_first=clean_ex_first)
forecast <- dplyr::bind_rows(fitted, forecast)
}

Expand Down

0 comments on commit a9ac12b

Please sign in to comment.