diff --git a/CHANGES.md b/CHANGES.md
index 4ffd2b82d..b81b3eadd 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - The behavior of method `net.get_params` changed to make it more consistent with sklearn: it will no longer return "learned" attributes like `module_`; therefore, functions like `sklearn.base.clone`, when called with a fitted net, will no longer return a fitted net but instead an uninitialized net; if you want a copy of a fitted net, use `copy.deepcopy` instead; `net.get_params` is used under the hood by many sklearn functions and classes, such as `GridSearchCV`, whose behavior may thus be affected by the change. (#521, #527)
 - Raise `FutureWarning` when using `CyclicLR` scheduler, because the default behavior has changed from taking a step every batch to taking a step every epoch. (#626)
 - Set train/validation on criterion if it's a PyTorch module (#621)
+- Don't pass `y=None` to `NeuralNet.train_split` to enable the direct use of split functions without a positional `y` in their signatures. This is useful when working with unsupervised data (#605).
 
 ### Fixed
 
diff --git a/skorch/net.py b/skorch/net.py
index aa6c9e1d9..ac82813c9 100644
--- a/skorch/net.py
+++ b/skorch/net.py
@@ -1210,12 +1210,16 @@ def get_split_datasets(self, X, y=None, **fit_params):
 
         """
         dataset = self.get_dataset(X, y)
-        if self.train_split:
-            dataset_train, dataset_valid = self.train_split(
-                dataset, y, **fit_params)
-        else:
-            dataset_train, dataset_valid = dataset, None
-        return dataset_train, dataset_valid
+        if not self.train_split:
+            return dataset, None
+
+        # After a change (#646), `y` is no longer passed to
+        # `self.train_split` if it is `None`. To revert to the previous
+        # behavior, remove the following two lines:
+        if y is None:
+            return self.train_split(dataset, **fit_params)
+
+        return self.train_split(dataset, y, **fit_params)
 
     def get_iterator(self, dataset, training=False):
         """Get an iterator that allows to loop over the batches of the
diff --git a/skorch/tests/callbacks/test_scoring.py b/skorch/tests/callbacks/test_scoring.py
index 5764f0dd3..ce8959695 100644
--- a/skorch/tests/callbacks/test_scoring.py
+++ b/skorch/tests/callbacks/test_scoring.py
@@ -401,13 +401,16 @@ def __init__(self, X, y):
         class MySkorchDataset(skorch.dataset.Dataset):
             pass
 
-        rawsplit = lambda ds, _: (ds, ds)
+        rawsplit = lambda ds: (ds, ds)
         cvsplit = CVSplit(2, random_state=0)
 
+        def split_ignore_y(ds, y):
+            return rawsplit(ds)
+
         table = [
             # Test a split where type(input) == type(output) is guaranteed
-            (data, rawsplit, np.ndarray, False),
-            (data, rawsplit, skorch.dataset.Dataset, True),
+            (data, split_ignore_y, np.ndarray, False),
+            (data, split_ignore_y, skorch.dataset.Dataset, True),
             ((MyTorchDataset(*data), None), rawsplit, MyTorchDataset, False),
             ((MyTorchDataset(*data), None), rawsplit, MyTorchDataset, True),
             ((MySkorchDataset(*data), None), rawsplit, np.ndarray, False),
diff --git a/skorch/tests/test_net.py b/skorch/tests/test_net.py
index 78ba9741b..6d07ed2ed 100644
--- a/skorch/tests/test_net.py
+++ b/skorch/tests/test_net.py
@@ -2490,6 +2490,48 @@ def initialize_module(self, *args, **kwargs):
         hidden_units = net.mymodule_.state_dict()['sequential.3.weight'].shape[1]
         assert hidden_units == 99
 
+    @pytest.mark.parametrize("needs_y, train_split, raises", [
+        (False, None, ExitStack()),  # ExitStack = does not raise
+        (True, None, ExitStack()),
+        (False, "default", ExitStack()),  # Default parameters for NeuralNet
+        (True, "default", ExitStack()),  # Default parameters for NeuralNet
+        (False, lambda x: (x, x), ExitStack()),  # Earlier this was not allowed
+        (True, lambda x, y: (x, x), ExitStack()),  # Works for custom split
+        (True, lambda x: (x, x), pytest.raises(TypeError)),  # Raises an error
+    ])
+    def test_passes_y_to_train_split_when_not_none(
+            self, needs_y, train_split, raises):
+        from skorch.net import NeuralNet
+        from skorch.toy import MLPModule
+
+        # By default, `train_split=CVSplit(5)` in the `NeuralNet` definition
+        kwargs = {} if train_split == 'default' else {
+            'train_split': train_split}
+
+        # Dummy loss that ignores y_true
+        class UnsupervisedLoss(torch.nn.NLLLoss):
+            def forward(self, y_pred, _):
+                return y_pred.mean()
+
+        # Generate a dummy dataset
+        n_samples, n_features = 128, 10
+        X = np.random.rand(n_samples, n_features).astype(np.float32)
+        y = np.random.binomial(n=1, p=0.5, size=n_samples) if needs_y else None
+
+        # `NeuralNetClassifier` and `NeuralNetRegressor` always require `y`;
+        # only `NeuralNet` can pass `y=None` through to `train_split`.
+        net = NeuralNet(
+            MLPModule,  # Any module; the choice is not important here
+            module__input_units=n_features,
+            max_epochs=2,  # Run the train loop twice to catch possible errors
+            criterion=UnsupervisedLoss,
+            **kwargs,
+        )
+
+        # Check that fitting raises (or doesn't raise) as expected
+        with raises:
+            net.fit(X, y)
+
 
 class TestNetSparseInput:
     @pytest.fixture(scope='module')
diff --git a/skorch/utils.py b/skorch/utils.py
index e37d6de6a..1624a90d5 100644
--- a/skorch/utils.py
+++ b/skorch/utils.py
@@ -475,7 +475,7 @@ def get_step(self):
         return self.step
 
 
-def _make_split(X, y, valid_ds, **kwargs):
+def _make_split(X, valid_ds, **kwargs):
     """Used by ``predefined_split`` to allow for pickling"""
     return X, valid_ds