diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index 8a82efe3d..e046438fa 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -386,6 +386,9 @@ def wrap(self, data): if data is None: return data + if hasattr(data, 'data'): + data = data.data + if is_type_of(data, Sup3rDataset): return data diff --git a/sup3r/preprocessing/batch_handlers/factory.py b/sup3r/preprocessing/batch_handlers/factory.py index 3a65b64bb..4037f2224 100644 --- a/sup3r/preprocessing/batch_handlers/factory.py +++ b/sup3r/preprocessing/batch_handlers/factory.py @@ -265,26 +265,26 @@ def init_samplers( """Initialize samplers from given data containers.""" train_samplers = [ self.SAMPLER( - data=c.data, + data=container, sample_shape=sample_shape, feature_sets=feature_sets, batch_size=batch_size, **sampler_kwargs, ) - for c in train_containers + for container in train_containers ] val_samplers = ( [] if val_containers is None else [ self.SAMPLER( - data=c.data, + data=container, sample_shape=sample_shape, feature_sets=feature_sets, batch_size=batch_size, **sampler_kwargs, ) - for c in val_containers + for container in val_containers ] ) return train_samplers, val_samplers diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py index b3083e11e..f5b6b49b9 100644 --- a/sup3r/preprocessing/batch_queues/abstract.py +++ b/sup3r/preprocessing/batch_queues/abstract.py @@ -11,13 +11,13 @@ import threading import time from abc import ABC, abstractmethod -from collections import namedtuple from typing import TYPE_CHECKING, List, Optional, Union import dask import numpy as np import tensorflow as tf +from sup3r.preprocessing.base import DsetTuple from sup3r.preprocessing.collections.base import Collection from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer @@ -32,7 +32,7 @@ class AbstractBatchQueue(Collection, ABC): generator and maintains a queue of batches in a dedicated thread so the training routine can proceed as soon as batches are available.""" - Batch = namedtuple('Batch', ['low_res', 'high_res']) + BATCH_MEMBERS = ('low_res', 'high_res') def __init__( self, @@ -46,6 +46,7 @@ def __init__( max_workers: int = 1, thread_name: str = 'training', mode: str = 'lazy', + verbose: bool = False ): """ Parameters @@ -77,6 +78,8 @@ def __init__( Loading mode. Default is 'lazy', which only loads data into memory as batches are queued. 'eager' will load all data into memory right away. + verbose : bool + Whether to log timing information for batch steps. """ msg = ( f'{self.__class__.__name__} requires a list of samplers. ' @@ -101,6 +104,7 @@ def __init__( 'smoothing_ignore': [], 'smoothing': None, } + self.verbose = verbose self.timer = Timer() self.preflight() @@ -174,7 +178,7 @@ def transform(self, samples, **kwargs): high res samples. For a dual dataset queue this will just include smoothing.""" - def post_proc(self, samples) -> Batch: + def post_proc(self, samples) -> DsetTuple: """Performs some post proc on dequeued samples before sending out for training. Post processing can include coarsening on high-res data (if :class:`Collection` consists of :class:`Sampler` objects and not @@ -182,13 +186,12 @@ def post_proc(self, samples) -> Batch: Returns ------- - Batch : namedtuple - namedtuple with `low_res` and `high_res` attributes. Could also - include additional members for integration with - ``DualSamplerWithObs`` + Batch : DsetTuple + namedtuple-like object with `low_res` and `high_res` attributes. + Could also include `obs` member. """ tsamps = self.transform(samples, **self.transform_kwargs) - return self.Batch(**dict(zip(self.Batch._fields, tsamps))) + return DsetTuple(**dict(zip(self.BATCH_MEMBERS, tsamps))) def start(self) -> None: """Start thread to keep sample queue full for batches.""" @@ -216,7 +219,7 @@ def __iter__(self): self.start() return self - def get_batch(self) -> Batch: + def get_batch(self) -> DsetTuple: """Get batch from queue or directly from a ``Sampler`` through ``sample_batch``.""" if ( @@ -274,10 +277,10 @@ def enqueue_batches(self) -> None: logger.debug(self.log_queue_info()) log_time = time.time() - def __next__(self) -> Batch: + def __next__(self) -> DsetTuple: """Dequeue batch samples, squeeze if for a spatial only model, perform some post-proc like smoothing, coarsening, etc, and then send out for - training as a namedtuple of low_res / high_res arrays. + training as a namedtuple-like object of low_res / high_res arrays. Returns ------- @@ -295,11 +298,12 @@ def __next__(self) -> Batch: batch = self.post_proc(samples) self.timer.stop() self._batch_count += 1 - logger.debug( - 'Batch step %s finished in %s.', - self._batch_count, - self.timer.elapsed_str, - ) + if self.verbose: + logger.debug( + 'Batch step %s finished in %s.', + self._batch_count, + self.timer.elapsed_str, + ) else: raise StopIteration return batch diff --git a/sup3r/preprocessing/batch_queues/conditional.py b/sup3r/preprocessing/batch_queues/conditional.py index 43d479b5f..488e70b2c 100644 --- a/sup3r/preprocessing/batch_queues/conditional.py +++ b/sup3r/preprocessing/batch_queues/conditional.py @@ -2,12 +2,12 @@ import logging from abc import abstractmethod -from collections import namedtuple from typing import TYPE_CHECKING, Dict, List, Optional, Union import numpy as np from sup3r.models.conditional import Sup3rCondMom +from sup3r.preprocessing.base import DsetTuple from sup3r.preprocessing.utilities import numpy_if_tensor from .base import SingleBatchQueue @@ -22,10 +22,6 @@ class ConditionalBatchQueue(SingleBatchQueue): """BatchQueue class for conditional moment estimation.""" - ConditionalBatch = namedtuple( - 'ConditionalBatch', ['low_res', 'high_res', 'output', 'mask'] - ) - def __init__( self, samplers: Union[List['Sampler'], List['DualSampler']], @@ -160,14 +156,14 @@ def post_proc(self, samples): Returns ------- - namedtuple - Named tuple with `low_res`, `high_res`, `mask`, and `output` - attributes + DsetTuple + Namedtuple-like object with `low_res`, `high_res`, `mask`, and + `output` attributes """ lr, hr = self.transform(samples, **self.transform_kwargs) mask = self.make_mask(high_res=hr) output = self.make_output(samples=(lr, hr)) - return self.ConditionalBatch( + return DsetTuple( low_res=lr, high_res=hr, output=output, mask=mask ) diff --git a/sup3r/preprocessing/batch_queues/dual.py b/sup3r/preprocessing/batch_queues/dual.py index 4b64c593d..9a89ce8bd 100644 --- a/sup3r/preprocessing/batch_queues/dual.py +++ b/sup3r/preprocessing/batch_queues/dual.py @@ -2,7 +2,6 @@ interface with models.""" import logging -from collections import namedtuple from scipy.ndimage import gaussian_filter @@ -21,7 +20,7 @@ def __init__(self, samplers, **kwargs): -------- :class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue` """ - self.Batch = namedtuple('Batch', samplers[0]._fields) + self.BATCH_MEMBERS = samplers[0]._fields super().__init__(samplers, **kwargs) self.check_enhancement_factors() diff --git a/tests/batch_queues/test_bq_general.py b/tests/batch_queues/test_bq_general.py index 52b2ffb25..b130aa7b7 100644 --- a/tests/batch_queues/test_bq_general.py +++ b/tests/batch_queues/test_bq_general.py @@ -12,6 +12,7 @@ DummyData, DummySampler, ) +from sup3r.utilities.utilities import Timer FEATURES = ['windspeed', 'winddirection'] @@ -53,6 +54,55 @@ def test_batch_queue(): batcher.stop() +def test_batch_queue_workers(): + """Check that using max_workers > 1 for a batch queue is faster than using + max_workers = 1.""" + + timer = Timer() + sample_shape = (10, 10, 20) + n_batches = 20 + batch_size = 10 + max_workers = 10 + n_epochs = 10 + chunk_shape = {'south_north': 20, 'west_east': 20, 'time': 40} + samplers = [ + DummySampler( + sample_shape, + data_shape=(100, 100, 1000), + batch_size=batch_size, + features=FEATURES, + chunk_shape=chunk_shape + ) + ] + batcher = SingleBatchQueue( + samplers=samplers, + n_batches=n_batches, + batch_size=batch_size, + max_workers=1, + ) + timer.start() + for _ in range(n_epochs): + _ = list(batcher) + timer.stop() + batcher.stop() + serial_time = timer.elapsed / (n_epochs * n_batches) + + batcher = SingleBatchQueue( + samplers=samplers, + n_batches=n_batches, + batch_size=batch_size, + max_workers=max_workers, + ) + timer.start() + for _ in range(n_epochs): + _ = list(batcher) + timer.stop() + batcher.stop() + parallel_time = timer.elapsed / (n_epochs * n_batches) + print(f'Parallel / Serial Time: {parallel_time} / {serial_time}') + assert parallel_time < serial_time + + def test_spatial_batch_queue(): """Smoke test for spatial batch queue. A batch queue returns batches for spatial models if the sample shapes have 1 for the time axis"""