diff --git a/pipelines/forecast_state.py b/pipelines/forecast_state.py index c5113117..e5fb1f04 100644 --- a/pipelines/forecast_state.py +++ b/pipelines/forecast_state.py @@ -3,12 +3,13 @@ import os import shutil import subprocess +import tomllib +import tomli_w from datetime import datetime, timedelta from pathlib import Path import numpyro import polars as pl -import yaml from prep_data import process_and_save_state from pygit2 import Repository from save_eval_data import save_eval_data @@ -20,11 +21,11 @@ def record_git_info(model_run_dir: Path): - metadata_file = Path(model_run_dir, "metadata.yaml") + metadata_file = Path(model_run_dir, "metadata.toml") if metadata_file.exists(): - with open(metadata_file, "r") as file: - metadata = yaml.safe_load(file) + with open(metadata_file, "rb") as file: + metadata = tomllib.load(file) else: metadata = {} @@ -43,17 +44,18 @@ def record_git_info(model_run_dir: Path): metadata.update(new_metadata) - with open(metadata_file, "w") as file: - yaml.dump(metadata, file) + metadata_file.parent.mkdir(parents=True, exist_ok=True) + with open(metadata_file, "wb") as file: + tomli_w.dump(metadata, file) def copy_and_record_priors(priors_path: Path, model_run_dir: Path): - metadata_file = Path(model_run_dir, "metadata.yaml") + metadata_file = Path(model_run_dir, "metadata.toml") shutil.copyfile(priors_path, Path(model_run_dir, "priors.py")) if metadata_file.exists(): - with open(metadata_file, "r") as file: - metadata = yaml.safe_load(file) + with open(metadata_file, "rb") as file: + metadata = tomllib.load(file) else: metadata = {} @@ -63,8 +65,8 @@ def copy_and_record_priors(priors_path: Path, model_run_dir: Path): metadata.update(new_metadata) - with open(metadata_file, "w") as file: - yaml.dump(metadata, file) + with open(metadata_file, "wb") as file: + tomli_w.dump(metadata, file) def generate_epiweekly(model_run_dir: Path) -> None: diff --git a/pyproject.toml b/pyproject.toml index e37307d1..e561d840 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ pyarrow = "^18.0.0" pygit2 = "^1.16.0" azuretools = {git = "https://github.com/cdcgov/cfa-azuretools"} forecasttools = {git = "https://github.com/CDCgov/forecasttools-py"} +tomli-w = "^1.1.0" [tool.poetry.group.test] optional = true