Skip to content

Commit

Permalink
new change in calling directory in calibrate_sir
Browse files Browse the repository at this point in the history
  • Loading branch information
sima-njf committed Dec 10, 2024
1 parent 2ed39c8 commit 76be545
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 51 deletions.
100 changes: 50 additions & 50 deletions R/calibrate_sir.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,62 +21,62 @@
#' @details
#' The function determines which pre-trained CNN model to load based on the number of features (columns) in the input `data`. If `data` has 30 columns, it loads the `sir30-cnn.keras` model; if it has 60 columns, it loads the `sir60-cnn.keras` model. Ensure that the input data matches one of these expected formats to avoid errors.
#' @export
calibrate_sir <- function(data) {
library(keras3)
ans=preprocessing_data(data)
a=length(ans)
ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1))

if(a <=30){
model <- keras3::load_model(
system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
)
}
else{
model <- keras3::load_model(
system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
)
}
pred <- predict(model, x =ans ) |>
data.table::as.data.table() |>
data.table::setnames(c("preval","crate","ptran","prec"))
pred$crate=qlogis(pred$crate)

return(list(pred = pred))
}
# calibrate_sir <- function(data) {
# # Load required libraries
# library(tensorflow)
# library(data.table)
#
# # Preprocess the data
# ans <- preprocessing_data(data)
# a <- length(ans)
# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model
# library(keras3)
# ans=preprocessing_data(data)
# a=length(ans)
# ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1))
#
# # Determine model file path
# model_path <- if (a <= 31) {
# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
# } else {
# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
# if(a <=30){
# model <- keras3::load_model(
# system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
# )
# }
#
# # Check if the model file exists
# if (model_path == "") {
# stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.")
# else{
# model <- keras3::load_model(
# system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
# )
# }
#
# # Load the model using tensorflow
# model <- tensorflow::tf$keras$models$load_model(model_path)
#
# # Make predictions
# pred <- model$predict(ans) |>
# pred <- predict(model, x =ans ) |>
# data.table::as.data.table() |>
# data.table::setnames(c("preval", "crate", "ptran", "prec"))
#
# data.table::setnames(c("preval","crate","ptran","prec"))
# pred$crate=qlogis(pred$crate)
#
# # Return predictions as a list
# return(list(pred = pred))
# }
#
calibrate_sir <- function(data) {
# Load required libraries
library(tensorflow)
library(data.table)

# Preprocess the data
ans <- preprocessing_data(data)
a <- length(ans)
ans <- tensorflow::array_reshape(ans, dim = c(1, 1, a, 1)) # Reshape for the model

# Determine model file path
model_path <- if (a <= 31) {
system.file("models", "sir30-cnn.keras", package = "epiworldRcalibrate")
} else {
system.file("models", "sir60-cnn.keras", package = "epiworldRcalibrate")
}

# Check if the model file exists
if (model_path == "") {
stop("Model file not found. Please ensure the models are included in the 'epiworldRcalibrate' package.")
}

# Load the model using tensorflow
model <- tensorflow::tf$keras$models$load_model(model_path)

# Make predictions
pred <- model$predict(ans) |>
data.table::as.data.table() |>
data.table::setnames(c("preval", "crate", "ptran", "prec"))


# Return predictions as a list
return(list(pred = pred))
}


1 change: 0 additions & 1 deletion epiworld-benchmark
Submodule epiworld-benchmark deleted from bc5295

0 comments on commit 76be545

Please sign in to comment.