diff --git a/R/layer_point_and_distn.R b/R/layer_point_and_distn.R index 24d82d59f..a0c3fe50a 100644 --- a/R/layer_point_and_distn.R +++ b/R/layer_point_and_distn.R @@ -10,7 +10,7 @@ layer_point_and_distn <- function(frosting, trainer, ..., point_id = NULL, point_type = c("median", "mean"), truncate = c(-Inf, Inf), - use_predictive_distribution = FALSE, + use_predictive_distribution = TRUE, dist_type = "gaussian") { rlang::check_dots_empty() stopifnot(inherits(recipe, "recipe")) @@ -43,7 +43,7 @@ layer_point_and_distn <- function(frosting, trainer, ..., if (is.null(distn_id)) { distn_id <- rand_id("residual_quantiles") } - if (use_predictive_distribution) { + if (inherits(trainer, "linear_reg") && use_predictive_distribution) { frosting %<>% layer_residual_quantiles( dist_type = dist_type, name = distn_name,