diff --git a/pandas_schema/validation.py b/pandas_schema/validation.py index 5f7c763..9e5af9f 100644 --- a/pandas_schema/validation.py +++ b/pandas_schema/validation.py @@ -9,7 +9,7 @@ from . import column from .validation_warning import ValidationWarning from .errors import PanSchArgumentError -from pandas.api.types import is_categorical_dtype, is_numeric_dtype +from pandas.api.types import is_categorical_dtype, is_datetime64_any_dtype, is_numeric_dtype class _BaseValidation: @@ -85,8 +85,8 @@ def get_errors(self, series: pd.Series, column: 'column.Column'): simple_validation = ~self.validate(series) if column.allow_empty: # Failing results are those that are not empty, and fail the validation - # explicitly check to make sure the series isn't a category because issubdtype will FAIL if it is - if is_categorical_dtype(series) or is_numeric_dtype(series): + # explicitly check to make sure the series isn't a category/datetime because issubdtype will FAIL if it is + if is_categorical_dtype(series) or is_datetime64_any_dtype(series) or is_numeric_dtype(series): validated = ~series.isnull() & simple_validation else: validated = (series.str.len() > 0) & simple_validation diff --git a/test/test_validation.py b/test/test_validation.py index fc40100..dc4b4e9 100644 --- a/test/test_validation.py +++ b/test/test_validation.py @@ -688,3 +688,27 @@ def test_invalid_elements(self): errors = self.validator.get_errors(pd.Series(['aa', 'bb', 'd'], dtype='category'), Column('', allow_empty=True)) self.assertEqual(len(errors), 3) + +class GetErrorAllowEmptyDatetimeTests(ValidationTestBase): + """ + Tests for datetime valued columns where allow_empty=True + """ + + def setUp(self): + match_val = datetime.datetime(2020, 11, 1) + self.validator = CustomSeriesValidation(lambda s: s == match_val, 'did not match target date') + + def test_valid(self): + series = pd.Series(['2020-11-01'], dtype='datetime64[ns]') + errors = self.validator.get_errors(series, Column('', allow_empty=True)) + self.assertEqual(len(errors), 0) + + def test_valid_invalid(self): + series = pd.Series(['2020-11-01', '2025-01-01'], dtype='datetime64[ns]') + errors = self.validator.get_errors(series, Column('', allow_empty=True)) + self.assertEqual(len(errors), 1) + + def test_valid_invalid_empty(self): + series = pd.Series(['2020-11-01', '2025-01-01', pd.NaT, np.NaN], dtype='datetime64[ns]') + errors = self.validator.get_errors(series, Column('', allow_empty=True)) + self.assertEqual(len(errors), 1)