Skip to content

Commit

Permalink
Merge pull request #344 from tidymodels/tidy-glmnet
Browse files Browse the repository at this point in the history
better tidy glmnet methods
  • Loading branch information
topepo authored Jul 7, 2020
2 parents bed8ca9 + c769cf0 commit a469d0a
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 9 deletions.
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ S3method(print,svm_poly)
S3method(print,svm_rbf)
S3method(req_pkgs,model_fit)
S3method(req_pkgs,model_spec)
S3method(tidy,"_elnet")
S3method(tidy,"_fishnet")
S3method(tidy,"_lognet")
S3method(tidy,"_multnet")
S3method(tidy,model_fit)
S3method(tidy,nullmodel)
S3method(translate,boost_tree)
Expand Down Expand Up @@ -234,6 +238,7 @@ importFrom(stats,.checkMFClasses)
importFrom(stats,.getXlevels)
importFrom(stats,as.formula)
importFrom(stats,binomial)
importFrom(stats,coef)
importFrom(stats,delete.response)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* Specific `tidy()` methods for `glmnet` models fit via `parsnip` were created so that the coefficients for the specific fitted `parsnip` model are returned.

# parsnip 0.1.2

## Breaking Changes
Expand Down
2 changes: 1 addition & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ utils::globalVariables(
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
"compute_intercept", "remove_intercept")
"compute_intercept", "remove_intercept", "estimate", "term")
)

# nocov end
6 changes: 3 additions & 3 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@
#' lm_model <-
#' linear_reg() %>%
#' set_engine("lm") %>%
#' fit(mpg ~ ., data = mtcars %>% slice(11:32))
#' fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))
#'
#' pred_cars <-
#' mtcars %>%
#' slice(1:10) %>%
#' select(-mpg)
#' dplyr::slice(1:10) %>%
#' dplyr::select(-mpg)
#'
#' predict(lm_model, pred_cars)
#'
Expand Down
63 changes: 63 additions & 0 deletions R/tidy_glmnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#' Tidy methods for glmnet models
#'
#' `tidy()` methods for the various `glmnet` models. They return the
#' coefficients for the specific penalty value used by the `parsnip` model fit.
#' @param x A fitted `parsnip` model that used the `glmnet` engine.
#' @param penalty A _single_ numeric value. If none is given, the value
#'   specified in the model specification is used.
#' @param ... Not used.
#' @return A tibble with columns `term`, `estimate`, and `penalty`. When a
#'   multinomial model is used, an additional `class` column is included.
#' @importFrom stats coef
#' @export
tidy._elnet <- function(x, penalty = NULL, ...) {
  # Delegate to the implementation shared by all glmnet tidy methods.
  tidy_glmnet(x, penalty = penalty)
}

#' @export
#' @rdname tidy._elnet
tidy._lognet <- function(x, penalty = NULL, ...) {
  # Delegate to the implementation shared by all glmnet tidy methods.
  tidy_glmnet(x, penalty = penalty)
}

#' @export
#' @rdname tidy._elnet
tidy._multnet <- function(x, penalty = NULL, ...) {
  # Delegate to the implementation shared by all glmnet tidy methods.
  tidy_glmnet(x, penalty = penalty)
}

#' @export
#' @rdname tidy._elnet
tidy._fishnet <- function(x, penalty = NULL, ...) {
  # Delegate to the implementation shared by all glmnet tidy methods.
  tidy_glmnet(x, penalty = penalty)
}

## -----------------------------------------------------------------------------

# Pull the coefficients of a glmnet fit at one penalty value into a tibble.
#
# x:       the underlying fit object; passed straight to `coef()`.
# penalty: the value handed to `coef()` as `s` and recorded in the `penalty`
#          column of the result.
#
# Returns a tibble with columns `term`, `estimate` and `penalty`. When the
# `estimate` column turns out to be a list (one coefficient matrix per row
# name of the coerced `coef()` result), it is expanded and a leading `class`
# column is added.
get_glmn_coefs <- function(x, penalty = 0.01) {
  coef_mat <- as.matrix(coef(x, s = penalty))
  colnames(coef_mat) <- "estimate"
  row_labels <- rownames(coef_mat)

  coef_tbl <- tibble::as_tibble(coef_mat)
  coef_tbl <- mutate(coef_tbl, term = row_labels, penalty = penalty)
  coef_tbl <- dplyr::select(coef_tbl, term, estimate, penalty)

  # Plain numeric estimates need no further reshaping.
  if (!is.list(coef_tbl$estimate)) {
    return(coef_tbl)
  }

  # Each list element holds its own coefficient matrix: turn each into a
  # (term, value) tibble and unnest, so the outer labels become `class`.
  coef_tbl$estimate <- purrr::map(
    coef_tbl$estimate,
    function(one_set) as_tibble(as.matrix(one_set), rownames = "term")
  )
  # "minimal" name repair tolerates the temporarily duplicated `term` column.
  coef_tbl <- tidyr::unnest(coef_tbl, cols = c(estimate), names_repair = "minimal")
  names(coef_tbl) <- c("class", "term", "estimate", "penalty")
  coef_tbl
}

# Shared implementation behind the glmnet `tidy()` methods.
#
# Works out which penalty to use and returns the coefficient tibble produced
# by `get_glmn_coefs()` for `x$fit`.
#
# x:       a fitted parsnip model; `x$spec` and `x$fit` are read.
# penalty: a single numeric value, or NULL to fall back to the penalty stored
#          in the model specification (`x$spec$args$penalty`).
# ...:     not used.
#
# Aborts with "Please pick a single value of `penalty`." when the resolved
# penalty (user-supplied or taken from the specification) is not a single,
# non-missing numeric value. Without this check such values would only fail
# later, with an unrelated error, while the coefficient table is being built.
tidy_glmnet <- function(x, penalty = NULL, ...) {
  check_installs(x$spec)
  # NOTE(review): `attach = TRUE` presumably attaches the engine package to
  # the search path as a side effect -- confirm against `load_libs()`.
  load_libs(x$spec, quiet = TRUE, attach = TRUE)

  if (is.null(penalty)) {
    penalty <- x$spec$args$penalty
  }

  is_single_number <-
    is.numeric(penalty) && length(penalty) == 1L && !is.na(penalty)
  if (!is_single_number) {
    rlang::abort("Please pick a single value of `penalty`.")
  }

  get_glmn_coefs(x$fit, penalty = penalty)
}
4 changes: 4 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
.onLoad <- function(libname, pkgname) {
  # Register this package's `tidy()` methods with the `broom::tidy` generic,
  # one class at a time, in a fixed order.
  tidy_classes <- c(
    "model_fit", "nullmodel",
    "_elnet", "_lognet", "_multnet", "_fishnet"
  )
  for (cls in tidy_classes) {
    s3_register("broom::tidy", cls)
  }
}


Expand Down
6 changes: 3 additions & 3 deletions man/predict.model_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions man/tidy._elnet.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions tests/testthat/test_multinom_reg_glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ test_that('glmnet probabilities, mulitiple lambda', {

for (i in seq_along(mult_class_res$.pred)) {
expect_equal(
mult_class %>% slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")),
mult_class_res %>% slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred"))
mult_class %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")),
mult_class_res %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred"))
)
}

Expand Down
58 changes: 58 additions & 0 deletions tests/testthat/test_tidy_glmnet.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
context("tidy glmnet models")

test_that("linear regression", {
  skip_if_not_installed("glmnet")

  ps_mod <-
    linear_reg(penalty = 0.1) %>%
    set_engine("glmnet") %>%
    fit(mpg ~ ., data = mtcars)

  # The tidied coefficients must match those read directly from the
  # underlying glmnet object at the same penalty.
  tidied <- tidy(ps_mod)
  direct <- as.matrix(coef(ps_mod$fit, s = 0.1))
  for (term_name in tidied$term) {
    tidied_value <- tidied$estimate[tidied$term == term_name]
    expect_equal(tidied_value, direct[term_name, 1])
  }
})

test_that("logistic regression", {
  skip_if_not_installed("glmnet")

  data(two_class_dat, package = "modeldata")

  ps_mod <-
    logistic_reg(penalty = 0.1) %>%
    set_engine("glmnet") %>%
    fit(Class ~ ., data = two_class_dat)

  # The tidied coefficients must match those read directly from the
  # underlying glmnet object at the same penalty.
  tidied <- tidy(ps_mod)
  direct <- as.matrix(coef(ps_mod$fit, s = 0.1))
  for (term_name in tidied$term) {
    tidied_value <- tidied$estimate[tidied$term == term_name]
    expect_equal(tidied_value, direct[term_name, 1])
  }
})

test_that("multinomial regression", {
  skip_if_not_installed("glmnet")

  data(penguins, package = "modeldata")

  ps_mod <-
    multinom_reg(penalty = 0.01) %>%
    set_engine("glmnet") %>%
    fit(species ~ ., data = penguins)

  tidied <- tidy(ps_mod)
  # `coef()` gives one coefficient matrix per class here; convert each to a
  # dense matrix so it can be indexed by term name.
  direct <- purrr::map(coef(ps_mod$fit, s = 0.01), as.matrix)

  # Compare every (term, class) pair against the direct glmnet value.
  for (term_name in unique(tidied$term)) {
    for (class_name in unique(tidied$class)) {
      is_match <- tidied$term == term_name & tidied$class == class_name
      expect_equal(
        tidied$estimate[is_match],
        direct[[class_name]][term_name, 1]
      )
    }
  }
})


0 comments on commit a469d0a

Please sign in to comment.