Skip to content

Commit

Permalink
hotfix: 24.5.1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
paxcema committed May 15, 2024
1 parent d485aae commit debce12
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
2 changes: 1 addition & 1 deletion dataprep_ml/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataprep_ml.base import StatisticalAnalysis, DataAnalysis

__version__ = '24.5.1.1'
__version__ = '24.5.1.2'
__name__ = "dataprep_ml"


Expand Down
27 changes: 19 additions & 8 deletions dataprep_ml/splitters.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,25 @@ def splitter(
train, dev, test = simple_split(data, pct_train, pct_dev, pct_test)

# Final assertions for time series
window = tss.get('window', 1) if tss.get('window', 1) else 1
horizon = tss.get('horizon', 1) if tss.get('horizon', 1) else 1

if min(len(train), len(dev)) < window:
raise Exception(f"Dataset size is too small for the specified window size ({window})")

if min(len(train), len(dev), len(test)) < horizon:
raise Exception(f"Dataset size is too small for the specified horizon size ({horizon})")
if tss.get('is_timeseries', False) not in (None, False):
window = tss.get('window', 1) if tss.get('window', 1) else 1
horizon = tss.get('horizon', 1) if tss.get('horizon', 1) else 1

if all([pct_train, pct_dev, pct_test]) > 0.0:
check_partitions = [train, dev, test]
elif all([pct_train, pct_test]) > 0.0:
check_partitions = [train, test]
elif all([pct_train, pct_dev]) > 0.0:
check_partitions = [train, dev]
else:
check_partitions = [train]
partition_lengths = [len(partition) for partition in check_partitions]

if min(partition_lengths) < window:
raise Exception(f"Dataset too small for the specified window size ({window}). Partition length: {partition_lengths}") # noqa

if min(partition_lengths) < horizon:
raise Exception(f"Dataset too small for the specified horizon size ({horizon}). Partition length: {partition_lengths}") # noqa

return {"train": train, "test": test, "dev": dev, "stratified_on": stratify_on}

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dataprep-ml"
version = "24.5.1.1"
version = "24.5.1.2"
description = "Automated dataframe analysis for Machine Learning pipelines."
authors = ["MindsDB Inc. <[email protected]>"]
license = "GPL-3.0"
Expand Down

0 comments on commit debce12

Please sign in to comment.