Skip to content

Commit

Permalink
Update Predictor checkpoint structure (#2984)
Browse files Browse the repository at this point in the history
* Update checkpoint structure

* Change to hyphen

* fix

* Use warning instead of warn

---------

Co-authored-by: Abdul Fatir Ansari <[email protected]>
  • Loading branch information
abdulfatir and Abdul Fatir Ansari authored Aug 30, 2023
1 parent 7c4b05d commit 0e6eb32
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions src/gluonts/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
if TYPE_CHECKING: # avoid circular import
from gluonts.model.estimator import Estimator # noqa


logger = logging.getLogger(__name__)
OutputTransform = Callable[[DataEntry, np.ndarray], np.ndarray]


Expand Down Expand Up @@ -77,11 +77,14 @@ def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:

def serialize(self, path: Path) -> None:
# serialize Predictor type
with (path / "type.txt").open("w") as fp:
fp.write(fqname_for(self.__class__))
with (path / "version.json").open("w") as fp:
with (path / "gluonts-config.json").open("w") as fp:
json.dump(
{"model": self.__version__, "gluonts": gluonts.__version__}, fp
{
"model": self.__version__,
"gluonts": gluonts.__version__,
"type": fqname_for(self.__class__),
},
fp,
)

@classmethod
Expand All @@ -99,8 +102,17 @@ def deserialize(cls, path: Path, **kwargs) -> "Predictor":
otherwise.
"""
# deserialize Predictor type
with (path / "type.txt").open("r") as fp:
tpe_str = fp.readline()
if (path / "gluonts-config.json").exists():
with (path / "gluonts-config.json").open("r") as fp:
tpe_str = json.load(fp)["type"]
else:
logger.warning(
"Deserializing an old version of gluonts predictor. "
"Support for old gluonts predictors will be removed in v0.16. "
"Consider serializing this predictor again.",
)
with (path / "type.txt").open("r") as fp:
tpe_str = fp.readline()

tpe = locate(tpe_str)
assert tpe is not None, f"Cannot locate {tpe_str}."
Expand Down Expand Up @@ -364,7 +376,6 @@ def __init__(self, estimator: "Estimator"):
self.estimator = estimator

def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
logger = logging.getLogger(__name__)
for i, ts in enumerate(dataset, start=1):
logger.info(f"training for time series {i} / {len(dataset)}")
trained_pred = self.estimator.train([ts])
Expand Down

0 comments on commit 0e6eb32

Please sign in to comment.