From 8cfc8a93256877a988c3d694465fe2b759d4c5e3 Mon Sep 17 00:00:00 2001 From: Pedro Eduardo Mercado Lopez Date: Thu, 30 Nov 2023 11:11:40 +0100 Subject: [PATCH] add validation option to TrainDataset --- src/gluonts/dataset/common.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/gluonts/dataset/common.py b/src/gluonts/dataset/common.py index 8bde27dd2d..26a1a00d49 100644 --- a/src/gluonts/dataset/common.py +++ b/src/gluonts/dataset/common.py @@ -73,11 +73,12 @@ class SourceContext(NamedTuple): class TrainDatasets(NamedTuple): """ A dataset containing two subsets, one to be used for training purposes, and - the other for testing purposes, as well as metadata. + the other for validation and testing purposes, as well as metadata. """ metadata: MetaData train: Dataset + validation: Optional[Dataset] = None test: Optional[Dataset] = None def save( @@ -114,6 +115,11 @@ def save( test.mkdir(parents=True) writer.write_to_folder(self.test, test) + if self.validation is not None: + validation = path / "validation" + validation.mkdir(parents=True) + writer.write_to_folder(self.validation, validation) + def infer_file_type(path): suffix = "".join(path.suffixes) @@ -427,6 +433,7 @@ def __call__(self, data: DataEntry) -> DataEntry: def load_datasets( metadata: Path, train: Path, + validation: Optional[Path], test: Optional[Path], one_dim_target: bool = True, cache: bool = False, @@ -442,6 +449,8 @@ def load_datasets( Path to the training dataset files. test Path to the test dataset files. + validation + Path to the validation dataset files. one_dim_target Whether to load FileDatasets as univariate target time series. cache @@ -467,4 +476,17 @@ def load_datasets( else None ) - return TrainDatasets(metadata=meta, train=train_ds, test=test_ds) + validation_ds = ( + FileDataset( + path=validation, + freq=meta.freq, + one_dim_target=one_dim_target, + cache=cache, + ) + if validation + else None + ) + + return TrainDatasets( + metadata=meta, train=train_ds, validation=validation_ds, test=test_ds + )