From 0e6eb32d92327e85c13763678d562403352c13b6 Mon Sep 17 00:00:00 2001 From: Abdul Fatir Date: Wed, 30 Aug 2023 16:40:32 +0200 Subject: [PATCH] Update Predictor checkpoint structure (#2984) * Update checkpoint structure * Change to hyphen * fix * Use warning instead of warn --------- Co-authored-by: Abdul Fatir Ansari --- src/gluonts/model/predictor.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/gluonts/model/predictor.py b/src/gluonts/model/predictor.py index 9e81ed9c88..0d1ac5e3d1 100644 --- a/src/gluonts/model/predictor.py +++ b/src/gluonts/model/predictor.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: # avoid circular import from gluonts.model.estimator import Estimator # noqa - +logger = logging.getLogger(__name__) OutputTransform = Callable[[DataEntry, np.ndarray], np.ndarray] @@ -77,11 +77,14 @@ def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]: def serialize(self, path: Path) -> None: # serialize Predictor type - with (path / "type.txt").open("w") as fp: - fp.write(fqname_for(self.__class__)) - with (path / "version.json").open("w") as fp: + with (path / "gluonts-config.json").open("w") as fp: json.dump( - {"model": self.__version__, "gluonts": gluonts.__version__}, fp + { + "model": self.__version__, + "gluonts": gluonts.__version__, + "type": fqname_for(self.__class__), + }, + fp, ) @classmethod @@ -99,8 +102,17 @@ def deserialize(cls, path: Path, **kwargs) -> "Predictor": otherwise. """ # deserialize Predictor type - with (path / "type.txt").open("r") as fp: - tpe_str = fp.readline() + if (path / "gluonts-config.json").exists(): + with (path / "gluonts-config.json").open("r") as fp: + tpe_str = json.load(fp)["type"] + else: + logger.warning( + "Deserializing an old version of gluonts predictor. " + "Support for old gluonts predictors will be removed in v0.16. " + "Consider serializing this predictor again.", + ) + with (path / "type.txt").open("r") as fp: + tpe_str = fp.readline() tpe = locate(tpe_str) assert tpe is not None, f"Cannot locate {tpe_str}." @@ -364,7 +376,6 @@ def __init__(self, estimator: "Estimator"): self.estimator = estimator def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]: - logger = logging.getLogger(__name__) for i, ts in enumerate(dataset, start=1): logger.info(f"training for time series {i} / {len(dataset)}") trained_pred = self.estimator.train([ts])