diff --git a/pyradigm/base.py b/pyradigm/base.py
index 07e6f9c..1e67761 100644
--- a/pyradigm/base.py
+++ b/pyradigm/base.py
@@ -34,6 +34,10 @@ class InfiniteOrNaNValuesException(PyradigmException):
     """Custom exception to catch NaN or Inf values."""
 
 
+class InvalidFeatureNamesException(PyradigmException):
+    """Custom exception to catch bad feature names."""
+
+
 class CompatibilityException(PyradigmException):
     """
     Exception to indicate two datasets are not compatible in some way
@@ -349,6 +353,51 @@ def _check_id(self, samplet_id):
         else:
             return samplet_id
 
+    def _check_feature_names(self, feature_names, nfeatures):
+        """
+        Validate feature names against the samplet and any existing names.
+
+        If nothing was previously used and nothing is given, the names are
+        set to strings like 0,1,...nfeats (generated by _str_names).
+
+        Parameters
+        ----------
+        feature_names : array-like or None
+            Names supplied along with the samplet being added.
+        nfeatures : int
+            Number of features in the samplet being added.
+
+        Returns
+        -------
+        feature_names as numpy array, or None when no names were given
+        and the dataset already holds samplets.
+
+        Raises
+        ------
+        InvalidFeatureNamesException
+            If the number of names differs from nfeatures, or if the names
+            differ from those previously stored on the dataset.
+        """
+        # on the first samplet we'll make names 0,...,nfeats
+        # but afterward, don't spend time generating them:
+        # add_samplet won't change the names if given None
+        if feature_names is None:
+            if self.num_samplets <= 0:
+                return self._str_names(nfeatures)
+            return None
+
+        feature_names = np.array(feature_names)
+        if len(feature_names) != nfeatures:
+            raise InvalidFeatureNamesException(
+                "Length of supplied feature_names does not match features")
+
+        if self._feature_names is not None and \
+                not np.array_equal(self.feature_names, feature_names):
+            raise InvalidFeatureNamesException(
+                "Supplied feature_names do not match what was previously specified")
+
+        return feature_names
+
 
     def add_samplet(self,
                     samplet_id,
@@ -403,33 +452,22 @@ def add_samplet(self,
 
         features = self._check_features(features)
         target = self._check_target(target)
-
-        if self.num_samplets <= 0:
-            self._data[samplet_id] = features
-            self._targets[samplet_id] = target
-            self._num_features = features.size if isinstance(features,
-                                                             np.ndarray) else len(
-                features)
-            if feature_names is None:
-                self._feature_names = self._str_names(self.num_features)
-        else:
-            if self._num_features != features.size:
-                raise ValueError('dimensionality of this samplet ({}) '
-                                 'does not match existing samplets ({})'
-                                 ''.format(features.size, self._num_features))
-
-            self._data[samplet_id] = features
-            self._targets[samplet_id] = target
-            if feature_names is not None:
-                # if it was never set, allow it
-                # class gets here when adding the first samplet,
-                # after dataset was initialized with empty constructor
-                if self._feature_names is None:
-                    self._feature_names = np.array(feature_names)
-                else: # if set already, ensure a match
-                    if not np.array_equal(self.feature_names, np.array(feature_names)):
-                        raise ValueError(
-                            "supplied feature names do not match the existing names!")
+        nfeats = features.size if isinstance(features, np.ndarray) \
+            else len(features)
+        # every samplet must have as many features as the existing ones
+        if self.num_samplets > 0 and self._num_features != nfeats:
+            raise ValueError('dimensionality of this samplet ({}) '
+                             'does not match existing samplets ({})'
+                             ''.format(nfeats, self._num_features))
+        feature_names = self._check_feature_names(feature_names, nfeats)
+        # TODO: attr should also be checked before _data is modified?
+
+        self._data[samplet_id] = features
+        self._targets[samplet_id] = target
+        self._num_features = nfeats
+        # even if given featnames=None, will be 0...nfeats for first samplet
+        if feature_names is not None and self._feature_names is None:
+            self._feature_names = feature_names
 
         if attr_names is not None:
             if is_iterable_but_not_str(attr_names):
diff --git a/pyradigm/tests/test_BaseDataset_common_behaviours.py b/pyradigm/tests/test_BaseDataset_common_behaviours.py
index e0ec5dd..9cf3541 100644
--- a/pyradigm/tests/test_BaseDataset_common_behaviours.py
+++ b/pyradigm/tests/test_BaseDataset_common_behaviours.py
@@ -14,7 +14,8 @@
                       RegressionDataset as RegrDataset)
 from pyradigm.utils import make_random_ClfDataset
 from pyradigm.base import is_iterable_but_not_str, PyradigmException, \
-    ConstantValuesException, InfiniteOrNaNValuesException, EmptyFeatureSetException
+    ConstantValuesException, InfiniteOrNaNValuesException, \
+    InvalidFeatureNamesException, EmptyFeatureSetException
 from pytest import raises, warns
 import numpy as np
 import random
@@ -253,7 +254,49 @@ def test_sanity_checks():
     with raises(ConstantValuesException):
         const_ds.save(out_file)
 
+
+def forall_dataset_types(func):
+    """Decorator: make a test run once for each dataset type.
+
+    Returns a zero-argument wrapper so that pytest can still collect the
+    test. functools.wraps is avoided on purpose: pytest would then read the
+    wrapped signature and go looking for a fixture named cls_type.
+    """
+    def run_for_each_type():
+        for cls_type in (RegrDataset, ClfDataset):
+            func(cls_type)
+    run_for_each_type.__name__ = func.__name__
+    run_for_each_type.__doc__ = func.__doc__
+    return run_for_each_type
+
+
+@forall_dataset_types
+def test_feature_names_len_mismatch(cls_type):
+    """feature names should align to features"""
+    clean_ds = cls_type()
+    feats2 = [1, 2]
+    # too long
+    with raises(InvalidFeatureNamesException):
+        clean_ds.add_samplet('b', feats2, 200, feature_names=['x', 'y', 'z'])
+    # too short
+    with raises(InvalidFeatureNamesException):
+        clean_ds.add_samplet('c', feats2, 100, feature_names=['x'])
+
+
+@forall_dataset_types
+def test_feature_names_change_order(cls_type):
+    """add samplet shouldn't change the order of features
+    TODO: this could be smarter and not throw an error
+    """
+    clean_ds = cls_type()
+    clean_ds.add_samplet('a', [1, 2], 100, feature_names=['x', 'y'])
+    # same names, but in a different order
+    with raises(InvalidFeatureNamesException):
+        clean_ds.add_samplet('c', [4, 3], 200,
+                             feature_names=['y', 'x'])
+
+
 test_attributes()
 # test_save_load()
 # test_sanity_checks()
-test_nan_inf_values()
\ No newline at end of file
+test_nan_inf_values()