Skip to content

Commit

Permalink
Merge pull request #225 from R4EPI/znk-fix-219
Browse files Browse the repository at this point in the history
Use the apyramid package
  • Loading branch information
zkamvar authored Dec 20, 2019
2 parents 9b44e40 + 24da70c commit 201f2c0
Show file tree
Hide file tree
Showing 26 changed files with 586 additions and 630 deletions.
8 changes: 5 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,20 @@ Imports:
survey,
srvyr,
stats,
utils
utils,
apyramid
Suggests:
testthat (>= 2.1.0),
sessioninfo,
vdiffr,
covr
Remotes:
reconhub/linelist
reconhub/linelist,
R4EPI/apyramid
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 6.1.1
RoxygenNote: 7.0.2
Collate:
'add_weights_cluster.R'
'add_weights_strata.R'
Expand Down
5 changes: 3 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ export(tab_survey)
export(tab_univariate)
export(unite_ci)
export(zcurve)
import(ggplot2)
importFrom(apyramid,age_pyramid)
importFrom(dplyr,bind_rows)
importFrom(dplyr,count)
importFrom(dplyr,funs)
Expand All @@ -56,16 +56,17 @@ importFrom(ggplot2,expand_scale)
importFrom(ggplot2,geom_density)
importFrom(ggplot2,ggplot)
importFrom(ggplot2,labs)
importFrom(ggplot2,scale_color_manual)
importFrom(ggplot2,scale_x_continuous)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,stat_density)
importFrom(ggplot2,stat_function)
importFrom(glue,glue)
importFrom(rio,import)
importFrom(rlang,"!!")
importFrom(rlang,".data")
importFrom(rlang,":=")
importFrom(rlang,sym)
importFrom(scales,percent)
importFrom(scales,percent_format)
importFrom(sf,st_intersection)
importFrom(sf,st_make_grid)
Expand Down
246 changes: 14 additions & 232 deletions R/age-pyramid.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
#' indicating the range of the un-facetted data set. Values of `spit_by` will
#' show up as labels at top of each facet.
#'
#' @import ggplot2
#' @importFrom scales percent
#' @importFrom apyramid age_pyramid
#' @export
#' @examples
#' library(ggplot2)
Expand Down Expand Up @@ -98,239 +97,22 @@
#' plot_age_pyramid(dat3, age_group = AGE)
#' theme_set(old)
plot_age_pyramid <- function(data, age_group = "age_group", split_by = "sex",
stack_by = split_by, proportional = FALSE, na.rm = FALSE,
stack_by = NULL, proportional = FALSE, na.rm = FALSE,
show_halfway = TRUE, vertical_lines = FALSE,
horizontal_lines = TRUE, pyramid = TRUE,
pal = NULL) {

is_df <- is.data.frame(data)
is_svy <- inherits(data, "tbl_svy")
age_group <- tidyselect::vars_select(colnames(data), !! enquo(age_group))
split_by <- tidyselect::vars_select(colnames(data), !! enquo(split_by))
stack_by <- tidyselect::vars_select(colnames(data), !! enquo(stack_by))

if (!is_df && !is_svy) {
msg <- sprintf("%s must be a data frame or object", deparse(substitute(data)))
stop(msg)
}

ag <- rlang::sym(age_group)
sb <- rlang::sym(split_by)
st <- rlang::sym(stack_by)

# Count the plot data --------------------------------------------------------
plot_data <- count_age_categories(data,
age_group,
split_by,
stack_by,
proportional,
na.rm)

# gathering the levels for each of the elements ------------------------------
age_levels <- levels(plot_data[[age_group]])
max_age_group <- age_levels[length(age_levels)]

# Splitting levels without missing data
split_levels <- plot_data[[split_by]]
split_levels <- if (is.factor(split_levels)) levels(split_levels) else unique(split_levels)
split_levels <- split_levels[!is.na(split_levels)]

# Stacking levels assuming there is no missing data
stk_levels <- plot_data[[stack_by]]
stk_levels <- if (is.factor(stk_levels)) levels(stk_levels) else unique(stk_levels)

stopifnot(length(split_levels) >= 1L)

# Switch between pyramid and non-pyramid shape -------------------------------
# This will only result in a pyramid if the user specifies so AND the split
# levels is binary.
split_measured_binary <- pyramid && length(split_levels) == 2L

if (split_measured_binary) {
maxdata <- dplyr::group_by(plot_data, !!ag, !!sb, .drop = FALSE)
} else {
maxdata <- dplyr::group_by(plot_data, !!ag, .drop = FALSE)
}

# find the maximum x axis position
maxdata <- dplyr::tally(maxdata, wt = !! quote(n))
max_n <- max(abs(maxdata[["n"]]), na.rm = TRUE)

# make sure the x axis is a multiple of ten. This took a lot of fiddling
if (proportional) {
max_n <- ceiling(max_n * 100)
max_n <- max_n + if (max_n %% 10 == 0) 0 else (10 - max_n %% 10)
step_size <- if (max_n > 25) 0.1 else if (max_n > 15) 0.05 else 0.01
max_n <- max_n / 100
lab_fun <- function(i) scales::percent(abs(i))
y_lab <- "proportion"
} else {
max_n <- max_n + if (max_n %% 10 == 0) 0 else (10 - max_n %% 10)
step_size <- ceiling(max_n / 5)
lab_fun <- abs
y_lab <- "counts"
}

stopifnot(is.finite(max_n), max_n > 0)

# Make sure the breaks are correct for the plot size
the_breaks <- seq(0, max_n, step_size)
the_breaks <- if (split_measured_binary) c(-rev(the_breaks[-1]), the_breaks) else the_breaks


if (split_measured_binary) {
# If the user has a binary level and wants to plot the data in a pyramid,
# then we need to make the counts for the primary level negative so that
# they appear to go to the left on the plot.
plot_data[["n"]] <- ifelse(plot_data[[split_by]] == split_levels[[1L]], -1L, 1L) * plot_data[["n"]]
maxdata[["d"]] <- ifelse(maxdata[[split_by]] == split_levels[[1L]], -1L, 0L) * maxdata[["n"]]
# If we are labelling the center, then we can get it by summing the values over the age groups
maxdata <- dplyr::summarise(maxdata, center = sum(!! quote(d)) + sum(!! quote(n)) / 2)
} else {
maxdata[["center"]] <- maxdata[["n"]] / 2
}

# Create base plot -----------------------------------------------------------
pyramid <- ggplot(plot_data, aes(x = !!ag, y = !!quote(n))) +
theme(axis.line.y = element_blank()) +
labs(y = y_lab)
pal <- if (is.function(pal)) pal(length(stk_levels)) else pal

if (!split_measured_binary) {
# add the background layer if the split is not binary
maxdata[["zzzzz_alpha"]] <- "Total"
pyramid <- pyramid +
geom_col(aes(alpha = !! quote(zzzzz_alpha)), fill = "grey80", color = "grey20", data = maxdata)
}

# Add bars, scales, and themes -----------------------------------------------
pyramid <- pyramid +
geom_col(aes(group = !!sb, fill = !!st), color = "grey20") +
coord_flip()
if (is.null(pal)) {
pyramid <- pyramid + scale_fill_brewer(type = "qual", guide = guide_legend(order = 1))
} else {
pyramid <- pyramid + scale_fill_manual(values = pal, guide = guide_legend(order = 1))
}
pyramid <- pyramid +
scale_y_continuous(limits = if (split_measured_binary) c(-max_n, max_n) else c(0, max_n),
breaks = the_breaks,
labels = lab_fun) +
scale_x_discrete(drop = FALSE) # note: drop = FALSE important to avoid missing age groups

if (!split_measured_binary) {
# Wrap the categories if the split is not binary
pyramid <- pyramid +
facet_wrap(split_by) +
scale_alpha_manual(values = 0.5, guide = guide_legend(title = NULL, order = 3))
}
if (vertical_lines == TRUE) {
pyramid <- pyramid +
geom_hline(yintercept = c(seq(-max_n, max_n, step_size)), linetype = "dotted", colour = "grey")
}


if (show_halfway) {
maxdata <- dplyr::arrange(maxdata, !! ag)
maxdata[['x']] <- seq_along(maxdata[[age_group]]) - 0.25
maxdata[['xend']] <- maxdata[['x']] + 0.5
maxdata[['halfway']] <- 'midpoint'
pyramid <- pyramid +
geom_segment(aes(x = !! quote(x),
xend = !! quote(xend),
y = !! quote(center),
yend = !! quote(center),
linetype = !! quote(halfway)),
color = "grey20",
key_glyph = "vpath", # NOTE: key_glyph is only part of ggplot2 >= 2.3.0; this will warn otherwise
data = maxdata) +
scale_linetype_manual(values = 'dashed', guide = guide_legend(title = NULL, order = 2))

}

if (split_measured_binary && stack_by != split_by) {
# If the split is binary and we have both stacked and split data, then we
# need to label the groups. We do so by adding a label annotation
pyramid <- pyramid +
annotate(geom = "label",
x = max_age_group,
y = -step_size,
vjust = 0.5,
hjust = 1,
label = split_levels[[1]]) +
annotate(geom = "label",
x = max_age_group,
y = step_size,
vjust = 0.5,
hjust = 0,
label = split_levels[[2]])
}

if (horizontal_lines == TRUE) {
pyramid <- pyramid + theme(panel.grid.major.y = element_line(linetype = 2))
}

pyramid <- pyramid +
geom_hline(yintercept = 0) # add vertical line

pyramid
}


# This will count the age categories for us. I've pulled it out of the plot-age
count_age_categories <- function(data, age_group, split_by, stack_by, proportional, na.rm) {

ag <- rlang::sym(age_group)
sb <- rlang::sym(split_by)
st <- rlang::sym(stack_by)

if (is.data.frame(data)) {
sbv <- data[[split_by]]
stv <- data[[stack_by]]
if (!is.character(sbv) || !is.factor(sbv)) {
data[[split_by]] <- as.character(sbv)
}
if (!is.character(stv) || !is.factor(stv)) {
data[[stack_by]] <- as.character(stv)
}
if (anyNA(sbv) || anyNA(stv)) {
nas <- is.na(sbv) | is.na(stv)
warning(sprintf("removing %d observations with missing values between the %s and %s columns.",
sum(nas), split_by, stack_by))
data <- data[!nas, , drop = FALSE]
}
if (na.rm) {
nas <- is.na(data[[age_group]])
warning(sprintf("removing %d observations with missing values from the %s column.",
sum(nas), age_group))
data <- data[!nas, , drop = FALSE]
} else {
data[[age_group]] <- forcats::fct_explicit_na(data[[age_group]])
}
# plot_data <- tidyr::complete(data, !!ag) # make sure all factors are represented
plot_data <- dplyr::group_by(data, !!ag, !!sb, !!st, .drop = FALSE)
plot_data <- dplyr::summarise(plot_data, n = dplyr::n())
plot_data <- dplyr::ungroup(plot_data)
if (is.factor(sbv)) {
plot_data[[split_by]] <- factor(plot_data[[split_by]], levels(sbv))
}
if (is.factor(stv)) {
plot_data[[stack_by]] <- factor(plot_data[[stack_by]], levels(stv))
}
} else {
plot_data <- srvyr::group_by(data, !!ag, !!sb, !!st, .drop = FALSE)
plot_data <- srvyr::summarise(plot_data,
n = srvyr::survey_total(vartype = "ci", level = 0.95))

}
# Remove any missing values
to_delete <- is.na(plot_data[[split_by]]) & is.na(plot_data[[stack_by]]) & plot_data[["n"]] == 0
plot_data <- plot_data[!to_delete, , drop = FALSE]

if (proportional) {
plot_data$n <- plot_data$n / sum(plot_data$n, na.rm = TRUE)
}

plot_data
age_pyramid(data,
age_group = {{age_group}},
split_by = {{split_by}},
stack_by = {{stack_by}},
proportional = proportional,
na.rm = na.rm,
show_midpoint = show_halfway,
vertical_lines = vertical_lines,
horizontal_lines = horizontal_lines,
pyramid = pyramid,
pal = pal
)

}
4 changes: 2 additions & 2 deletions R/descriptive-table.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ descriptive <- function(df, counter, grouper = NULL, multiplier = 100, digits =


# translate the variable names to character
counter <- tidyselect::vars_select(colnames(df), !!enquo(counter))
grouper <- tidyselect::vars_select(colnames(df), !!enquo(grouper))
counter <- tidyselect::vars_select(colnames(df), !!rlang::enquo(counter))
grouper <- tidyselect::vars_select(colnames(df), !!rlang::enquo(grouper))

has_grouper <- length(grouper) == 1
sym_count <- rlang::sym(counter)
Expand Down
2 changes: 1 addition & 1 deletion R/zcurve.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @param zscore bare name of a numeric vector containing computed zscores
#' @return a ggplot2 object that is customisable via the ggplot2 package.
#' @export
#' @importFrom ggplot2 ggplot aes stat_function geom_density scale_x_continuous scale_y_continuous labs expand_scale
#' @importFrom ggplot2 ggplot aes stat_function geom_density scale_x_continuous scale_y_continuous labs expand_scale stat_density scale_color_manual
#' @importFrom scales percent_format
#' @examples
#' library("ggplot2")
Expand Down
18 changes: 14 additions & 4 deletions man/add_weights_cluster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions man/add_weights_strata.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 201f2c0

Please sign in to comment.