diff --git a/dataprep_ml/__init__.py b/dataprep_ml/__init__.py index 0b4cf02..d8805b4 100644 --- a/dataprep_ml/__init__.py +++ b/dataprep_ml/__init__.py @@ -1,6 +1,6 @@ from dataprep_ml.base import StatisticalAnalysis, DataAnalysis -__version__ = '24.5.1.1' +__version__ = '24.5.1.2' __name__ = "dataprep_ml" diff --git a/dataprep_ml/splitters.py b/dataprep_ml/splitters.py index 8495d4c..708535e 100644 --- a/dataprep_ml/splitters.py +++ b/dataprep_ml/splitters.py @@ -58,14 +58,25 @@ def splitter( train, dev, test = simple_split(data, pct_train, pct_dev, pct_test) # Final assertions for time series - window = tss.get('window', 1) if tss.get('window', 1) else 1 - horizon = tss.get('horizon', 1) if tss.get('horizon', 1) else 1 - - if min(len(train), len(dev)) < window: - raise Exception(f"Dataset size is too small for the specified window size ({window})") - - if min(len(train), len(dev), len(test)) < horizon: - raise Exception(f"Dataset size is too small for the specified horizon size ({horizon})") + if tss.get('is_timeseries', False) not in (None, False): + window = tss.get('window', 1) if tss.get('window', 1) else 1 + horizon = tss.get('horizon', 1) if tss.get('horizon', 1) else 1 + + if all([pct_train, pct_dev, pct_test]) > 0.0: + check_partitions = [train, dev, test] + elif all([pct_train, pct_test]) > 0.0: + check_partitions = [train, test] + elif all([pct_train, pct_dev]) > 0.0: + check_partitions = [train, dev] + else: + check_partitions = [train] + partition_lengths = [len(partition) for partition in check_partitions] + + if min(partition_lengths) < window: + raise Exception(f"Dataset too small for the specified window size ({window}). Partition length: {partition_lengths}") # noqa + + if min(partition_lengths) < horizon: + raise Exception(f"Dataset too small for the specified horizon size ({horizon}). Partition length: {partition_lengths}") # noqa return {"train": train, "test": test, "dev": dev, "stratified_on": stratify_on} diff --git a/pyproject.toml b/pyproject.toml index 95c4e39..c432a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dataprep-ml" -version = "24.5.1.1" +version = "24.5.1.2" description = "Automated dataframe analysis for Machine Learning pipelines." authors = ["MindsDB Inc. "] license = "GPL-3.0"