Skip to content

Commit

Permalink
changes in structure and some code cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
LukaszChrostowski committed Dec 26, 2024
1 parent dfa876a commit 991b1ff
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 60 deletions.
86 changes: 86 additions & 0 deletions R/internals.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,89 @@ ff <- function(formula) {
l = l
)
}

# Point estimate built from two additive pieces:
#   1. a weighted sum of the differences `y - y_nons`, scaled by `N_nons`;
#   2. a weighted sum of `y_rand`, scaled by `N_rand`.
#
# Arguments
#   y, y_nons     values whose element-wise difference is weighted and summed
#                 (presumably observed vs. predicted outcomes - confirm with
#                 callers).
#   weights       first set of multipliers applied to the differences.
#   weights_nons  second set of multipliers applied to the differences.
#   y_rand        values summed in the second piece.
#   weights_rand  multipliers applied to `y_rand`.
#   N_nons        divisor of the first piece.
#   N_rand        divisor of the second piece.
#
# Returns the sum of the two scaled pieces.
mu_hatDR <- function(y,
                     y_nons,
                     y_rand,
                     weights,
                     weights_nons,
                     weights_rand,
                     N_nons,
                     N_rand) {
  differences <- y - y_nons
  residual_part <- sum(weights * weights_nons * differences) / N_nons
  rand_part <- sum(weights_rand * y_rand) / N_rand
  residual_part + rand_part
}

# Weighted total of `y` divided by `N`.
#
# Arguments
#   y             values to be weighted and summed.
#   weights       first set of multipliers applied to `y`.
#   weights_nons  second set of multipliers applied to `y`.
#   N             divisor applied to the weighted total.
#
# Returns sum(weights * weights_nons * y) / N.
mu_hatIPW <- function(y,
                      weights,
                      weights_nons,
                      N) {
  weighted_total <- sum(weights * weights_nons * y)
  weighted_total / N
}

# Fit the outcome model with `stats::glm()`.
#
# Arguments
#   outcome          model formula passed to `glm()` as `formula`.
#   data             data frame used for fitting; a `weights` column is added
#                    to a local copy, the caller's object is not modified.
#   weights          prior weights passed to `glm()`.
#   svydesign        accepted for interface compatibility; not used here.
#   family_outcome   family specification resolved by `process_family()`.
#   start            optional starting values passed to `glm()`.
#   control_outcome  list-like object whose `epsilon`, `maxit` and `trace`
#                    elements are forwarded to `glm()` as `control`.
#   verbose          if TRUE, print convergence status and iteration count.
#   model, x, y      passed through to the `glm()` arguments of the same name.
#
# Returns the fitted `glm` object. Any error raised while fitting is re-raised
# with the prefix "Error in model fitting: ".
nonprobMI_fit <- function(outcome,
                          data,
                          weights,
                          svydesign = NULL,
                          family_outcome = "gaussian",
                          start = NULL,
                          control_outcome = controlOut(),
                          verbose = FALSE,
                          model = TRUE,
                          x = FALSE,
                          y = FALSE) {
  # Process family specification
  family <- process_family(family_outcome)

  # Process control parameters; entries that are NULL are dropped so that
  # glm.control() falls back to its own defaults instead of rejecting them.
  control_list <- Filter(
    Negate(is.null),
    list(
      epsilon = control_outcome$epsilon,
      maxit = control_outcome$maxit,
      trace = control_outcome$trace
    )
  )

  # Work on a local copy so the weights column does not leak to the caller
  model_data <- data
  model_data$weights <- weights

  # Fit the model; only the glm() call itself is guarded
  model_fit <- tryCatch(
    stats::glm(
      formula = outcome,
      data = model_data,
      weights = weights,
      family = family,
      start = start,
      control = control_list,
      model = model,
      x = x,
      y = y
    ),
    error = function(e) {
      stop("Error in model fitting: ", conditionMessage(e), call. = FALSE)
    }
  )

  if (verbose) {
    status <- if (isTRUE(model_fit$converged)) "converged" else "not converged"
    cat("Model fitting completed:\n")
    cat("Convergence status:", status, "\n")
    cat("Number of iterations:", model_fit$iter, "\n")
  }

  model_fit
}

# Resolve a family specification to a `family` object.
#
# Accepted forms of `family_spec`:
#   * a single character string naming a family generator (e.g. "gaussian"),
#     looked up as a function starting from the caller's frame;
#   * a family generator function (e.g. `stats::binomial`), which is called
#     with no arguments;
#   * an object inheriting from class "family", returned unchanged.
#
# Anything else, including a function whose result is not a "family" object,
# raises "Invalid family specification".
process_family <- function(family_spec) {
  # A name is first resolved to the generator function ...
  if (is.character(family_spec) && length(family_spec) == 1L) {
    family_spec <- get(family_spec, mode = "function", envir = parent.frame())
  }
  # ... and a generator (given directly or just looked up) is then called, so
  # every accepted input form yields a family object rather than a function.
  if (is.function(family_spec)) {
    family_spec <- family_spec()
  }
  if (!inherits(family_spec, "family")) {
    stop("Invalid family specification")
  }
  family_spec
}
38 changes: 37 additions & 1 deletion R/nonprob.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ nonprob <- function(data,
if (is.null(selection) & is.null(outcome)) {
stop("Please provide selection or outcome formula.")
}

# Check formula inputs
if (!is.null(selection) && !inherits(selection, "formula")) {
stop("'selection' must be a formula")
}
if (!is.null(outcome) && !inherits(outcome, "formula")) {
stop("'outcome' must be a formula")
}
if (!is.null(target) && !inherits(target, "formula")) {
stop("'target' must be a formula")
}

if (inherits(selection, "formula") && (is.null(outcome) || inherits(outcome, "formula") == FALSE)) {
if (inherits(target, "formula") == FALSE) stop("Please provide target variable")
model_used <- "P"
Expand All @@ -71,7 +83,31 @@ nonprob <- function(data,
model_used <- "DR"
}

## validate data
# Check numeric inputs
if (!is.null(pop_size) && !is.numeric(pop_size)) {
stop("'pop_size' must be numeric")
}
if (!is.null(weights) && !is.numeric(weights)) {
stop("'weights' must be numeric")
}

# Check weights length
if (!is.null(weights) && length(weights) != nrow(data)) {
stop("Length of weights must match number of rows in data")
}

if (!is.null(pop_totals) && !is.null(pop_means)) {
stop("Cannot specify both pop_totals and pop_means")
}

if (!is.null(pop_size)) {
if (pop_size <= 0) {
stop("pop_size must be positive")
}
if (pop_size < nrow(data)) {
stop("pop_size cannot be smaller than sample size")
}
}

## model estimates
model_estimates <- switch(model_used,
Expand Down
12 changes: 0 additions & 12 deletions R/nonprobDR.R
Original file line number Diff line number Diff line change
Expand Up @@ -911,15 +911,3 @@ nonprobDR <- function(selection,
class = c("nonprobsvy", "nonprobsvy_dr")
)
}

# Point estimate formed by adding two scaled weighted sums:
#   (1 / N_nons) * sum(weights * weights_nons * (y - y_nons))
#   (1 / N_rand) * sum(weights_rand * y_rand)
#
# Arguments
#   y, y_nons     values whose element-wise difference enters the first sum.
#   weights       first set of multipliers in the first sum.
#   weights_nons  second set of multipliers in the first sum.
#   y_rand        values entering the second sum.
#   weights_rand  multipliers in the second sum.
#   N_nons        size used to scale the first sum.
#   N_rand        size used to scale the second sum.
#
# Returns the sum of the two scaled terms.
mu_hatDR <- function(y,
                     y_nons,
                     y_rand,
                     weights,
                     weights_nons,
                     weights_rand,
                     N_nons,
                     N_rand) {
  first_term <- (1 / N_nons) * sum(weights * weights_nons * (y - y_nons))
  second_term <- (1 / N_rand) * sum(weights_rand * y_rand)
  first_term + second_term
}
9 changes: 0 additions & 9 deletions R/nonprobIPW.R
Original file line number Diff line number Diff line change
Expand Up @@ -603,12 +603,3 @@ nonprobIPW <- function(selection,
class = c("nonprobsvy", "nonprobsvy_ipw")
)
}


# Weighted total of `y` scaled by the reciprocal of `N`.
#
# Arguments
#   y             values to be weighted and summed.
#   weights       first set of multipliers applied to `y`.
#   weights_nons  second set of multipliers applied to `y`.
#   N             size whose reciprocal scales the weighted total.
#
# Returns (1 / N) * sum(weights * weights_nons * y).
mu_hatIPW <- function(y,
                      weights,
                      weights_nons,
                      N) {
  reciprocal_n <- 1 / N
  reciprocal_n * sum(weights * weights_nons * y)
}
38 changes: 0 additions & 38 deletions R/nonprobMI.R
Original file line number Diff line number Diff line change
Expand Up @@ -499,41 +499,3 @@ nonprobMI <- function(outcome,
class = c("nonprobsvy", "nonprobsvy_mi")
)
}



# Fit the outcome model with `stats::glm()`.
#
# Arguments
#   outcome          model formula passed to `glm()` as `formula`.
#   data             data frame used for fitting; a `weights` column is added
#                    to the function's local copy.
#   weights          prior weights passed to `glm()`.
#   family_outcome   a family name (character), a family generator function,
#                    or a "family" object.
#   start            starting values passed to `glm()`.
#   control_outcome  list-like object whose `epsilon`, `maxit` and `trace`
#                    elements are forwarded to `glm()` as `control`.
#   svydesign, verbose, model, x, y
#                    accepted for interface compatibility; not used here.
#
# Returns the fitted `glm` object.
nonprobMI_fit <- function(outcome,
                          data,
                          weights,
                          svydesign,
                          family_outcome,
                          start,
                          control_outcome = controlOut(),
                          verbose,
                          model,
                          x,
                          y) {
  family <- family_outcome

  # A name is resolved to the generator function, which is then called, so
  # glm() always receives a family object.
  if (is.character(family)) {
    family <- get(family, mode = "function", envir = parent.frame())
  }
  if (is.function(family)) {
    family <- family()
  }
  if (!inherits(family, "family")) {
    stop("Invalid family specification")
  }

  data$weights <- weights # TODO just for now, find more efficient way
  model_nons <- stats::glm(
    formula = outcome,
    data = data,
    weights = weights,
    family = family,
    start = start,
    # Named entries: do not rely on the positional order of glm.control()
    control = list(
      epsilon = control_outcome$epsilon,
      maxit = control_outcome$maxit,
      trace = control_outcome$trace
    )
  )

  model_nons
}

0 comments on commit 991b1ff

Please sign in to comment.