Skip to content

Commit

Permalink
update from June codebase (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaitejohnson authored Jun 5, 2024
1 parent 3ba0210 commit fa10fe0
Show file tree
Hide file tree
Showing 72 changed files with 2,752 additions and 529 deletions.
55 changes: 55 additions & 0 deletions .github/actions/install-cmdstan/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
name: 'Install CmdStan with caching'
description: 'Install CmdStan with caching'
inputs:
cmdstan-version:
description: 'CmdStan version to install (use "latest" for the latest version)'
required: false
default: 'latest'
num-cores:
description: 'Number of cores to use for building CmdStan'
required: false
default: '1'

runs:
using: 'composite'
steps:
- name: Determine CmdStan Version (Unix)
if: runner.os != 'Windows' && inputs.cmdstan-version == 'latest'
run: |
chmod +x ${{ github.action_path }}/scripts/get-latest-release.sh
${{ github.action_path }}/scripts/get-latest-release.sh
shell: bash

- name: Determine CmdStan Version (Windows)
if: runner.os == 'Windows' && inputs.cmdstan-version == 'latest'
run: ${{ github.action_path }}\scripts\get-latest-release.ps1
shell: pwsh

- name: Set CmdStan Version (Specified)
if: inputs.cmdstan-version != 'latest'
run: echo "CMDSTAN_VERSION=${{ inputs.cmdstan-version }}" >> $GITHUB_ENV
shell: bash

- name: Restore Cache
id: cache-cmdstan
uses: actions/cache@v4
with:
path: '~/.cmdstan/cmdstan-${{ env.CMDSTAN_VERSION }}'
key: ${{ runner.os }}-cmdstan-${{ env.CMDSTAN_VERSION }}

- name: Check the CmdStan toolchain and repair it if required
if: steps.cache-cmdstan.outputs.cache-hit != 'true'
run: |
Rscript -e 'cmdstanr::check_cmdstan_toolchain(fix = TRUE)'
shell: bash

- name: Install CmdStan using cmdstanr
if: steps.cache-cmdstan.outputs.cache-hit != 'true'
run: |
Rscript -e 'cmdstanr::install_cmdstan(version = "${{ env.CMDSTAN_VERSION }}", cores = ${{ inputs.num-cores }})'
shell: bash

- name: Set Cmdstan path
run: |
Rscript -e 'cmdstanr::set_cmdstan_path("~/.cmdstan/cmdstan-${{ env.CMDSTAN_VERSION }}")'
shell: bash
18 changes: 14 additions & 4 deletions .github/actions/install-cmdstan/scripts/get-latest-release.ps1
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
# PowerShell script to fetch the latest CmdStan release version with retry logic
# PowerShell script to fetch the latest CmdStan release version with enhanced retry logic

# Initialize retry parameters
$max_attempts = 5
$wait_time = 5 # seconds
$version = $null

for ($attempt = 1; $attempt -le $max_attempts; $attempt++) {
try {
$response = Invoke-RestMethod -Uri "https://api.github.com/repos/stan-dev/cmdstan/releases/latest" -ErrorAction Stop
$version = $response.tag_name -replace '^v', ''
# Using Invoke-RestMethod to fetch the latest release data
$response = Invoke-RestMethod -Uri "https://api.github.com/repos/stan-dev/cmdstan/releases/latest" -Method Get -ErrorAction Stop
$version = $response.tag_name -replace '^v', '' # Remove 'v' from version if present

# Check if the version is successfully retrieved
if (-not [string]::IsNullOrWhiteSpace($version)) {
"CMDSTAN_VERSION=$version" | Out-File -Append -FilePath $env:GITHUB_ENV
Write-Host "CmdStan latest version: $version"
break
}
} catch {
# Handle different types of errors
if ($_.Exception.Response) {
$statusCode = $_.Exception.Response.StatusCode.value__
Write-Host "HTTP status code: $statusCode"
}

Write-Host "Attempt $attempt of $max_attempts failed. Retrying in $wait_time seconds..."
Start-Sleep -Seconds $wait_time
$wait_time = $wait_time * 2
$wait_time = $wait_time * 2 # Exponential backoff
}
}

# Check if the version was never set and handle the failure
if ([string]::IsNullOrWhiteSpace($version)) {
Write-Host "Failed to fetch CmdStan version after $max_attempts attempts."
exit 1
Expand Down
55 changes: 29 additions & 26 deletions .github/actions/install-cmdstan/scripts/get-latest-release.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,35 @@ else
exit 1
fi

# Function to get the latest CmdStan version using GitHub API
get_latest_version() {
local retries=3
local wait_time=5
local status=0
local version=""

for ((i=0; i<retries; i++)); do
version=$(curl -s https://api.github.com/repos/stan-dev/cmdstan/releases/latest | jq -r '.tag_name' | tr -d 'v')
status=$?
if [ $status -eq 0 ] && [ -n "$version" ]; then
echo $version
return 0
fi
sleep $wait_time
wait_time=$((wait_time*2))
done

return 1
}

# Fetch the latest release version of CmdStan
version=$(get_latest_version)

if [ $? -ne 0 ] || [ -z "$version" ]; then
echo "Failed to fetch the latest CmdStan version"
retries=3
wait_time=5
status=0
version=""

for ((i=0; i<retries; i++)); do
echo "Attempt $((i+1)) of $retries"
# Save the response body to a temporary file and capture HTTP status code separately
response=$(curl -s -w "%{http_code}" -o temp.json https://api.github.com/repos/stan-dev/cmdstan/releases/latest)
http_code=$(echo $response | tail -n1) # Extract the HTTP status code
version=$(jq -r '.tag_name' temp.json | tr -d 'v')
rm temp.json
echo "HTTP status code: $http_code"
echo "Fetched version: $version"

if [[ $http_code == 200 ]] && [ -n "$version" ]; then
echo "Successfully fetched version: $version"
break
else
echo "Failed to fetch version or bad HTTP status. HTTP status: $http_code, Version fetched: '$version'"
fi

sleep $wait_time
echo "Sleeping for $wait_time seconds before retrying..."
wait_time=$((wait_time*2))
done

if [ $status -ne 0 ] || [ -z "$version" ]; then
echo "Failed to fetch the latest CmdStan version after $retries attempts."
exit 1
fi

Expand Down
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ repos:
- id: check-toml
- id: end-of-file-fixer
- id: mixed-line-ending
args: ['--fix=lf']
- id: trailing-whitespace
#####
# R
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ See our [model definition page](model_definition.md) for further details on the
To run our code, you will need a working installation of [R](https://www.r-project.org/) (version `4.3.0` or later). You can find instructions for installing R on the official [R project website](https://www.r-project.org/).

## Install `cmdstanr` and `CmdStan`
We do inference from our models using [`CmdStan`](https://mc-stan.org/users/interfaces/cmdstan) (version `2.34.1` or later) via its R interface [`cmdstanr`](https://mc-stan.org/cmdstanr/) (version `0.7.1` or later).
We do inference from our models using [`CmdStan`](https://mc-stan.org/users/interfaces/cmdstan) (version `2.35.0` or later) via its R interface [`cmdstanr`](https://mc-stan.org/cmdstanr/) (version `0.8.0` or later).

Open an R session and run the following command to install `cmdstanr` per that package's [official installation guide](https://mc-stan.org/cmdstanr/#installation).

Expand Down
11 changes: 11 additions & 0 deletions _targets.R
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,17 @@ list(
iteration = "list",
deployment = "main"
),
tar_target(
name = plot_autoreg_rt_site,
command = get_plot_single_param(
grouped_df_id,
param_name = "autoreg_rt_site",
figure_output_subdirectory
),
pattern = map(grouped_df_id),
iteration = "list",
deployment = "main"
),
tar_target(
name = plot_p_hosp_mean,
command = get_plot_single_param(
Expand Down
65 changes: 55 additions & 10 deletions _targets_eval.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ setup_interactive_dev_run <- function() {
tar_option_set(
packages = c(
"cmdstanr",
"rlang",
"tibble",
"ggplot2",
"dplyr",
Expand Down Expand Up @@ -105,6 +106,10 @@ upstream_targets <- list(
last_hosp_data_date = eval_config$eval_date,
ww_data_mapping = eval_config$ww_data_mapping
)
),
tar_target(
name = table_of_exclusions,
command = tibble::as_tibble(eval_config$table_of_exclusions)
)
)

Expand Down Expand Up @@ -137,24 +142,33 @@ mapped_ww <- tar_map(
priority = 1
),
tar_target(
name = input_hosp_data,
name = raw_input_hosp_data,
command = get_input_hosp_data(forecast_date, location,
hosp_data_dir = eval_config$hosp_data_dir,
calibration_time = eval_config$calibration_time
),
deployment = "main",
priority = 1
),
tar_target(
name = input_hosp_data,
command = exclude_hosp_outliers(
raw_input_hosp_data = raw_input_hosp_data,
forecast_date = forecast_date,
table_of_exclusions = table_of_exclusions
)
),
tar_target(
name = last_hosp_data_date,
command = get_last_hosp_data_date(input_hosp_data),
deployment = "main",
priority = 1
),
tar_target(input_ww_data,
command = get_input_ww_data(forecast_date,
location,
scenario,
command = get_input_ww_data(
forecast_date = forecast_date,
location = location,
scenario = scenario,
scenario_dir = eval_config$scenario_dir,
ww_data_dir = eval_config$ww_data_dir,
calibration_time = eval_config$calibration_time,
Expand All @@ -169,8 +183,11 @@ mapped_ww <- tar_map(
name = standata,
command = get_stan_data_list(
model_type = "ww",
forecast_date, eval_config$forecast_time,
input_ww_data, input_hosp_data,
forecast_date = forecast_date,
forecast_time = eval_config$forecast_time,
calibration_time = eval_config$calibration_time,
input_ww_data = input_ww_data,
input_hosp_data = input_hosp_data,
generation_interval = eval_config$generation_interval,
inf_to_hosp = eval_config$inf_to_hosp,
infection_feedback_pmf = eval_config$infection_feedback_pmf,
Expand Down Expand Up @@ -484,13 +501,21 @@ mapped_hosp <- tar_map(
deployment = "main"
),
tar_target(
name = input_hosp_data,
name = raw_input_hosp_data,
command = get_input_hosp_data(forecast_date, location,
hosp_data_dir = eval_config$hosp_data_dir,
calibration_time = eval_config$calibration_time
),
deployment = "main"
),
tar_target(
name = input_hosp_data,
command = exclude_hosp_outliers(
raw_input_hosp_data = raw_input_hosp_data,
forecast_date = forecast_date,
table_of_exclusions = table_of_exclusions
)
),
tar_target(
name = last_hosp_data_date,
command = get_last_hosp_data_date(input_hosp_data),
Expand All @@ -501,13 +526,15 @@ mapped_hosp <- tar_map(
name = standata,
command = get_stan_data_list(
model_type = "hosp",
forecast_date, eval_config$forecast_time,
forecast_date = forecast_date,
forecast_time = eval_config$forecast_time,
calibration_time = eval_config$calibration_time,
input_ww_data = NA,
input_hosp_data = input_hosp_data,
generation_interval = eval_config$generation_interval,
inf_to_hosp = eval_config$inf_to_hosp,
infection_feedback_pmf = eval_config$infection_feedback_pmf,
params
params = params
),
deployment = "main"
),
Expand Down Expand Up @@ -746,7 +773,7 @@ downstream_targets <- list(
# a mix of model types
tar_target(
name = summarized_raw_scores,
command = scoringutils::summarize_scores(all_raw_scores,
command = scoringutils::summarize_scores(all_ww_scores,
by = c(
"scenario",
"period",
Expand Down Expand Up @@ -800,6 +827,24 @@ downstream_targets <- list(
)
),

## Summary score for model feature comparison-------------------------
tar_target(
name = baseline_score,
command = make_baseline_score_table(
mock_submission_scores |> dplyr::filter(scenario == "status_quo"),
baseline_score_table_dir = eval_config$baseline_score_table_dir,
overwrite_table = eval_config$overwrite_summary_table
)
),
tar_target(
name = baseline_score_hosp,
command = make_baseline_score_table(
mock_submission_scores |> dplyr::filter(scenario == "no_wastewater"),
baseline_score_table_dir = eval_config$baseline_score_table_dir,
overwrite_table = eval_config$overwrite_summary_table
)
),

## Plots----------------------------------------------------
tar_target(
name = plot_raw_scores,
Expand Down
Loading

0 comments on commit fa10fe0

Please sign in to comment.