Skip to content

Commit

Permalink
Track git commit id when fitting models (#242)
Browse files Browse the repository at this point in the history
* user git commit sha as build arg for container

* change SHA reference

* record git info

* add branch name to container

* extract branch name in correct github action

* try again

* minor tweaks

* add debug step

* more debugging

* better handling when no git repo is present

* Remove debug steps
  • Loading branch information
damonbayer authored Dec 16, 2024
1 parent a69fa0c commit a45085c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .github/workflows/containers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
outputs:
tag: ${{ steps.image-tag.outputs.tag }}
commit-msg: ${{ steps.commit-message.outputs.message }}
branch: ${{ steps.branch-name.outputs.branch }}

steps:

Expand Down Expand Up @@ -115,7 +116,9 @@ jobs:
with:
push: true # This can be toggled manually for tweaking.
tags: |
${{ env.REGISTRY}}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }}
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ needs.build-dependencies-image.outputs.tag }}
file: ./Containerfile
build-args: |
TAG=${{ needs.build-dependencies-image.outputs.tag }}
GIT_COMMIT_SHA=${{ github.event.pull_request.head.sha || github.sha }}
GIT_BRANCH_NAME=${{ needs.build-dependencies-image.outputs.branch }}
5 changes: 5 additions & 0 deletions Containerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
ARG TAG=latest

FROM cfaprdbatchcr.azurecr.io/pyrenew-hew-dependencies:${TAG}
ARG GIT_COMMIT_SHA
ENV GIT_COMMIT_SHA=$GIT_COMMIT_SHA

ARG GIT_BRANCH_NAME
ENV GIT_BRANCH_NAME=$GIT_BRANCH_NAME

COPY ./hewr /pyrenew-hew/hewr

Expand Down
26 changes: 26 additions & 0 deletions pipelines/forecast_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

import numpyro
import polars as pl
import yaml
from prep_data import process_and_save_state
from pygit2 import Repository
from save_eval_data import save_eval_data

numpyro.set_host_device_count(4)
Expand All @@ -17,6 +19,27 @@
from generate_predictive import generate_and_save_predictions # noqa


def record_git_info(model_run_dir: Path):
metadata_file = Path(model_run_dir, "metadata.yaml")
try:
repo = Repository(os.getcwd())
branch_name = os.environ.get(
"GIT_BRANCH_NAME", Path(repo.head.name).stem
)
commit_sha = os.environ.get("GIT_COMMIT_SHA", str(repo.head.target))
except:
branch_name = os.environ.get("GIT_BRANCH_NAME", "unknown")
commit_sha = os.environ.get("GIT_COMMIT_SHA", "unknown")

metadata = {
"branch_name": branch_name,
"commit_sha": commit_sha,
}

with open(metadata_file, "w") as file:
yaml.dump(metadata, file)


def generate_epiweekly(model_run_dir: Path) -> None:
result = subprocess.run(
[
Expand Down Expand Up @@ -238,6 +261,9 @@ def main(

os.makedirs(model_run_dir, exist_ok=True)

logger.info("Recording git info...")
record_git_info(model_run_dir)

logger.info(f"Using priors from {priors_path}...")
shutil.copyfile(priors_path, Path(model_run_dir, "priors.py"))

Expand Down

0 comments on commit a45085c

Please sign in to comment.