diff --git a/nobrainer/processing/generation.py b/nobrainer/processing/generation.py index bc0da104..51b46625 100644 --- a/nobrainer/processing/generation.py +++ b/nobrainer/processing/generation.py @@ -5,7 +5,7 @@ from .base import BaseEstimator from .. import losses -from ..dataset import get_dataset +from ..dataset import Dataset class ProgressiveGeneration(BaseEstimator): @@ -147,15 +147,17 @@ def _compile(): if batch_size % self.strategy.num_replicas_in_sync: raise ValueError("batch size must be a multiple of the number of GPUs") - dataset = get_dataset( + dataset = Dataset.from_tfrecords( file_pattern=info.get("file_pattern"), - batch_size=batch_size, num_parallel_calls=num_parallel_calls, volume_shape=(resolution, resolution, resolution), n_classes=1, - scalar_label=True, - normalizer=info.get("normalizer") or normalizer, + scalar_labels=True, ) + n_epochs = info.get("epochs") or epochs + dataset.batch(batch_size).normalize( + info.get("normalizer") or normalizer + ).repeat(n_epochs) with self.strategy.scope(): # grow the networks by one (2^x) resolution @@ -164,9 +166,7 @@ def _compile(): self.model_.discriminator.add_resolution() _compile() - steps_per_epoch = (info.get("epochs") or epochs) // info.get( - "batch_size" - ) + steps_per_epoch = n_epochs // info.get("batch_size") # save_best_only is set to False as it is an adversarial loss model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint( @@ -182,7 +182,7 @@ def _compile(): print("Transition phase") self.model_.fit( - dataset, + dataset.dataset, phase="transition", resolution=resolution, steps_per_epoch=steps_per_epoch, # necessary for repeat dataset @@ -191,7 +191,7 @@ def _compile(): print("Resolution phase") self.model_.fit( - dataset, + dataset.dataset, phase="resolution", resolution=resolution, steps_per_epoch=steps_per_epoch,