From 76be54581ed036dc48027fca75d0a868337b6727 Mon Sep 17 00:00:00 2001 From: Sima Najafzadehkhoei Date: Tue, 10 Dec 2024 15:19:00 -0700 Subject: [PATCH] new change in calling directory in calibrate_sir --- R/calibrate_sir.R | 100 ++++++++++++++++++++++----------------------- epiworld-benchmark | 1 - 2 files changed, 50 insertions(+), 51 deletions(-) delete mode 160000 epiworld-benchmark diff --git a/R/calibrate_sir.R b/R/calibrate_sir.R index cba7ce5..fc84081 100644 --- a/R/calibrate_sir.R +++ b/R/calibrate_sir.R @@ -21,62 +21,62 @@ #' @details #' The function determines which pre-trained CNN model to load based on the number of features (columns) in the input `data`. If `data` has 30 columns, it loads the `sir30-cnn.keras` model; if it has 60 columns, it loads the `sir60-cnn.keras` model. Ensure that the input data matches one of these expected formats to avoid errors. #' @export -calibrate_sir <- function(data) { - library(keras3) - ans=preprocessing_data(data) - a=length(ans) - ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) - - if(a <=30){ - model <- keras3::load_model( - system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") - ) - } - else{ - model <- keras3::load_model( - system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") - ) - } - pred <- predict(model, x =ans ) |> - data.table::as.data.table() |> - data.table::setnames(c("preval","crate","ptran","prec")) - pred$crate=qlogis(pred$crate) - - return(list(pred = pred)) -} # calibrate_sir <- function(data) { -# # Load required libraries -# library(tensorflow) -# library(data.table) -# -# # Preprocess the data -# ans <- preprocessing_data(data) -# a <- length(ans) -# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model +# library(keras3) +# ans=preprocessing_data(data) +# a=length(ans) +# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # -# # Determine model file path -# model_path <- if (a <= 31) { -# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") -# } else { -# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") +# if(a <=30){ +# model <- keras3::load_model( +# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") +# ) # } -# -# # Check if the model file exists -# if (model_path == "") { -# stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.") +# else{ +# model <- keras3::load_model( +# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") +# ) # } -# -# # Load the model using tensorflow -# model <- tensorflow::tf$keras$models$load_model(model_path) -# -# # Make predictions -# pred <- model$predict(ans) |> +# pred <- predict(model, x =ans ) |> # data.table::as.data.table() |> -# data.table::setnames(c("preval", "crate", "ptran", "prec")) -# +# data.table::setnames(c("preval","crate","ptran","prec")) +# pred$crate=qlogis(pred$crate) # -# # Return predictions as a list # return(list(pred = pred)) # } -# +calibrate_sir <- function(data) { + # Load required libraries + library(tensorflow) + library(data.table) + + # Preprocess the data + ans <- preprocessing_data(data) + a <- length(ans) + ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model + + # Determine model file path + model_path <- if (a <= 31) { + system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate") + } else { + system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate") + } + + # Check if the model file exists + if (model_path == "") { + stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.") + } + + # Load the model using tensorflow + model <- tensorflow::tf$keras$models$load_model(model_path) + + # Make predictions + pred <- model$predict(ans) |> + data.table::as.data.table() |> + data.table::setnames(c("preval", "crate", "ptran", "prec")) + + + # Return predictions as a list + return(list(pred = pred)) +} + diff --git a/epiworld-benchmark b/epiworld-benchmark deleted file mode 160000 index bc52952..0000000 --- a/epiworld-benchmark +++ /dev/null @@ -1 +0,0 @@ -Subproject commit bc5295249719dbda9565fa6c0ea25c6fe7a7b175