diff --git a/nobrainer/dataset.py b/nobrainer/dataset.py index 368ac739..31d2686b 100644 --- a/nobrainer/dataset.py +++ b/nobrainer/dataset.py @@ -11,7 +11,7 @@ from .io import _is_gzipped, verify_features_labels from .tfrecord import _labels_all_scalar, parse_example_fn, write -from .volume import binarize, replace, standardize, to_blocks +from .volume import binarize, replace, to_blocks AUTOTUNE = tf.data.experimental.AUTOTUNE @@ -315,9 +315,9 @@ def map_labels(self, label_mapping=None): if self.n_classes == 1: self.map(lambda x, y: (x, tf.expand_dims(binarize(y), -1))) elif self.n_classes == 2: - self.map(lambda x, y: (x, tf.one_hot(binarize(y), n_classes))) + self.map(lambda x, y: (x, tf.one_hot(binarize(y), self.n_classes))) elif self.n_classes > 2: - self.map(lambda x, y: (x, tf.one_hot(y, n_classes))) + self.map(lambda x, y: (x, tf.one_hot(y, self.n_classes))) return self