Skip to content

Commit

Permalink
Merge branch '73-update-layer' of https://github.com/cmu-delphi/epipr…
Browse files Browse the repository at this point in the history
…edict into 73-update-layer
  • Loading branch information
rachlobay committed Sep 16, 2023
2 parents 572cc70 + 074d979 commit 2b49c2e
Show file tree
Hide file tree
Showing 121 changed files with 1,897 additions and 4,139 deletions.
3 changes: 3 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# using styler at all
aca7d5e7b66d8bac9d9fbcec3acdb98a087d58fa
f12fcc2bf3fe0a75ba2b10eaaf8a1f1d22486a17
81 changes: 81 additions & 0 deletions .github/workflows/styler.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
on:
workflow_dispatch:
pullrequest:
paths:
[
"**.[rR]",
"**.[qrR]md",
"**.[rR]markdown",
"**.[rR]nw",
"**.[rR]profile",
]

name: Style

jobs:
style:
runs-on: ubuntu-latest
env:
GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
steps:
- name: Checkout repo
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Setup R
uses: r-lib/actions/setup-r@v2
with:
use-public-rspm: true

- name: Install dependencies
uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::styler, any::roxygen2
needs: styler

- name: Enable styler cache
run: styler::cache_activate()
shell: Rscript {0}

- name: Determine cache location
id: styler-location
run: |
cat(
"location=",
styler::cache_info(format = "tabular")$location,
"\n",
file = Sys.getenv("GITHUB_OUTPUT"),
append = TRUE,
sep = ""
)
shell: Rscript {0}

- name: Cache styler
uses: actions/cache@v3
with:
path: ${{ steps.styler-location.outputs.location }}
key: ${{ runner.os }}-styler-${{ github.sha }}
restore-keys: |
${{ runner.os }}-styler-
${{ runner.os }}-
- name: Style
run: styler::style_pkg()
shell: Rscript {0}

- name: Commit and push changes
run: |
if FILES_TO_COMMIT=($(git diff-index --name-only ${{ github.sha }} \
| egrep --ignore-case '\.(R|[qR]md|Rmarkdown|Rnw|Rprofile)$'))
then
git config --local user.name "$GITHUB_ACTOR"
git config --local user.email "[email protected]"
git commit ${FILES_TO_COMMIT[*]} -m "Style code (GHA)"
git pull --ff-only
git push origin
else
echo "No changes to commit."
fi
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Imports:
Suggests:
covidcast,
data.table,
epidatr,
epidatr (>= 1.0.0),
ggplot2,
knitr,
lubridate,
Expand Down
68 changes: 36 additions & 32 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ arx_classifier <- function(
predictors,
trainer = parsnip::logistic_reg(),
args_list = arx_class_args_list()) {

if (!is_classification(trainer))
if (!is_classification(trainer)) {
cli::cli_abort("`trainer` must be a {.pkg parsnip} model of mode 'classification'.")
}

wf <- arx_class_epi_workflow(
epi_data, outcome, predictors, trainer, args_list
Expand All @@ -65,13 +65,15 @@ arx_classifier <- function(
tibble::as_tibble() %>%
dplyr::select(-time_value)

structure(list(
predictions = preds,
epi_workflow = wf,
metadata = list(
training = attr(epi_data, "metadata"),
forecast_created = Sys.time()
)),
structure(
list(
predictions = preds,
epi_workflow = wf,
metadata = list(
training = attr(epi_data, "metadata"),
forecast_created = Sys.time()
)
),
class = c("arx_class", "canned_epipred")
)
}
Expand Down Expand Up @@ -117,12 +119,13 @@ arx_class_epi_workflow <- function(
predictors,
trainer = NULL,
args_list = arx_class_args_list()) {

validate_forecaster_inputs(epi_data, outcome, predictors)
if (!inherits(args_list, c("arx_class", "alist")))
if (!inherits(args_list, c("arx_class", "alist"))) {
rlang::abort("args_list was not created using `arx_class_args_list().")
if (!(is.null(trainer) || is_classification(trainer)))
}
if (!(is.null(trainer) || is_classification(trainer))) {
rlang::abort("`trainer` must be a `{parsnip}` model of mode 'classification'.")
}
lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
Expand Down Expand Up @@ -172,8 +175,10 @@ arx_class_epi_workflow <- function(
o2 <- rlang::sym(paste0("ahead_", args_list$ahead, "_", o))
r <- r %>%
step_epi_ahead(!!o, ahead = args_list$ahead, role = "pre-outcome") %>%
step_mutate(outcome_class = cut(!!o2, breaks = args_list$breaks),
role = "outcome") %>%
step_mutate(
outcome_class = cut(!!o2, breaks = args_list$breaks),
role = "outcome"
) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)

Expand Down Expand Up @@ -245,9 +250,7 @@ arx_class_args_list <- function(
method = c("rel_change", "linear_reg", "smooth_spline", "trend_filter"),
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf
) {

nafill_buffer = Inf) {
.lags <- lags
if (is.list(lags)) lags <- unlist(lags)
method <- match.arg(method)
Expand All @@ -266,7 +269,8 @@ arx_class_args_list <- function(
cli::cli_abort(
c("`additional_gr_args` must be a {.cls list}.",
"!" = "This is a {.cls {class(additional_gr_args)}}.",
i = "See `?epiprocess::growth_rate` for available arguments.")
i = "See `?epiprocess::growth_rate` for available arguments."
)
)
}

Expand All @@ -277,19 +281,20 @@ arx_class_args_list <- function(

max_lags <- max(lags)
structure(
enlist(lags = .lags,
ahead,
n_training,
breaks,
forecast_date,
target_date,
outcome_transform,
max_lags,
horizon,
method,
log_scale,
additional_gr_args,
nafill_buffer
enlist(
lags = .lags,
ahead,
n_training,
breaks,
forecast_date,
target_date,
outcome_transform,
max_lags,
horizon,
method,
log_scale,
additional_gr_args,
nafill_buffer
),
class = c("arx_class", "alist")
)
Expand All @@ -300,4 +305,3 @@ print.arx_class <- function(x, ...) {
name <- "ARX Classifier"
NextMethod(name = name, ...)
}

Loading

0 comments on commit 2b49c2e

Please sign in to comment.