Skip to content

Commit

Permalink
parallel batch queue test added.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Dec 27, 2024
1 parent b8344ad commit 3165ce0
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 31 deletions.
3 changes: 3 additions & 0 deletions sup3r/preprocessing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ def wrap(self, data):
if data is None:
return data

if hasattr(data, 'data'):
data = data.data

if is_type_of(data, Sup3rDataset):
return data

Expand Down
8 changes: 4 additions & 4 deletions sup3r/preprocessing/batch_handlers/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,26 @@ def init_samplers(
"""Initialize samplers from given data containers."""
train_samplers = [
self.SAMPLER(
data=c.data,
data=container,
sample_shape=sample_shape,
feature_sets=feature_sets,
batch_size=batch_size,
**sampler_kwargs,
)
for c in train_containers
for container in train_containers
]
val_samplers = (
[]
if val_containers is None
else [
self.SAMPLER(
data=c.data,
data=container,
sample_shape=sample_shape,
feature_sets=feature_sets,
batch_size=batch_size,
**sampler_kwargs,
)
for c in val_containers
for container in val_containers
]
)
return train_samplers, val_samplers
Expand Down
36 changes: 20 additions & 16 deletions sup3r/preprocessing/batch_queues/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
import threading
import time
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import TYPE_CHECKING, List, Optional, Union

import dask
import numpy as np
import tensorflow as tf

from sup3r.preprocessing.base import DsetTuple
from sup3r.preprocessing.collections.base import Collection
from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer

Expand All @@ -32,7 +32,7 @@ class AbstractBatchQueue(Collection, ABC):
generator and maintains a queue of batches in a dedicated thread so the
training routine can proceed as soon as batches are available."""

Batch = namedtuple('Batch', ['low_res', 'high_res'])
BATCH_MEMBERS = ('low_res', 'high_res')

def __init__(
self,
Expand All @@ -46,6 +46,7 @@ def __init__(
max_workers: int = 1,
thread_name: str = 'training',
mode: str = 'lazy',
verbose: bool = False
):
"""
Parameters
Expand Down Expand Up @@ -77,6 +78,8 @@ def __init__(
Loading mode. Default is 'lazy', which only loads data into memory
as batches are queued. 'eager' will load all data into memory right
away.
verbose : bool
Whether to log timing information for batch steps.
"""
msg = (
f'{self.__class__.__name__} requires a list of samplers. '
Expand All @@ -101,6 +104,7 @@ def __init__(
'smoothing_ignore': [],
'smoothing': None,
}
self.verbose = verbose
self.timer = Timer()
self.preflight()

Expand Down Expand Up @@ -174,21 +178,20 @@ def transform(self, samples, **kwargs):
high res samples. For a dual dataset queue this will just include
smoothing."""

def post_proc(self, samples) -> Batch:
def post_proc(self, samples) -> DsetTuple:
"""Performs some post proc on dequeued samples before sending out for
training. Post processing can include coarsening on high-res data (if
:class:`Collection` consists of :class:`Sampler` objects and not
:class:`DualSampler` objects), smoothing, etc
Returns
-------
Batch : namedtuple
namedtuple with `low_res` and `high_res` attributes. Could also
include additional members for integration with
``DualSamplerWithObs``
Batch : DsetTuple
namedtuple-like object with `low_res` and `high_res` attributes.
Could also include `obs` member.
"""
tsamps = self.transform(samples, **self.transform_kwargs)
return self.Batch(**dict(zip(self.Batch._fields, tsamps)))
return DsetTuple(**dict(zip(self.BATCH_MEMBERS, tsamps)))

def start(self) -> None:
"""Start thread to keep sample queue full for batches."""
Expand Down Expand Up @@ -216,7 +219,7 @@ def __iter__(self):
self.start()
return self

def get_batch(self) -> Batch:
def get_batch(self) -> DsetTuple:
"""Get batch from queue or directly from a ``Sampler`` through
``sample_batch``."""
if (
Expand Down Expand Up @@ -274,10 +277,10 @@ def enqueue_batches(self) -> None:
logger.debug(self.log_queue_info())
log_time = time.time()

def __next__(self) -> Batch:
def __next__(self) -> DsetTuple:
"""Dequeue batch samples, squeeze if for a spatial only model, perform
some post-proc like smoothing, coarsening, etc, and then send out for
training as a namedtuple of low_res / high_res arrays.
training as a namedtuple-like object of low_res / high_res arrays.
Returns
-------
Expand All @@ -295,11 +298,12 @@ def __next__(self) -> Batch:
batch = self.post_proc(samples)
self.timer.stop()
self._batch_count += 1
logger.debug(
'Batch step %s finished in %s.',
self._batch_count,
self.timer.elapsed_str,
)
if self.verbose:
logger.debug(
'Batch step %s finished in %s.',
self._batch_count,
self.timer.elapsed_str,
)
else:
raise StopIteration
return batch
Expand Down
14 changes: 5 additions & 9 deletions sup3r/preprocessing/batch_queues/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import logging
from abc import abstractmethod
from collections import namedtuple
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import numpy as np

from sup3r.models.conditional import Sup3rCondMom
from sup3r.preprocessing.base import DsetTuple
from sup3r.preprocessing.utilities import numpy_if_tensor

from .base import SingleBatchQueue
Expand All @@ -22,10 +22,6 @@
class ConditionalBatchQueue(SingleBatchQueue):
"""BatchQueue class for conditional moment estimation."""

ConditionalBatch = namedtuple(
'ConditionalBatch', ['low_res', 'high_res', 'output', 'mask']
)

def __init__(
self,
samplers: Union[List['Sampler'], List['DualSampler']],
Expand Down Expand Up @@ -160,14 +156,14 @@ def post_proc(self, samples):
Returns
-------
namedtuple
Named tuple with `low_res`, `high_res`, `mask`, and `output`
attributes
DsetTuple
Namedtuple-like object with `low_res`, `high_res`, `mask`, and
`output` attributes
"""
lr, hr = self.transform(samples, **self.transform_kwargs)
mask = self.make_mask(high_res=hr)
output = self.make_output(samples=(lr, hr))
return self.ConditionalBatch(
return DsetTuple(
low_res=lr, high_res=hr, output=output, mask=mask
)

Expand Down
3 changes: 1 addition & 2 deletions sup3r/preprocessing/batch_queues/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
interface with models."""

import logging
from collections import namedtuple

from scipy.ndimage import gaussian_filter

Expand All @@ -21,7 +20,7 @@ def __init__(self, samplers, **kwargs):
--------
:class:`~sup3r.preprocessing.batch_queues.abstract.AbstractBatchQueue`
"""
self.Batch = namedtuple('Batch', samplers[0]._fields)
self.BATCH_MEMBERS = samplers[0]._fields
super().__init__(samplers, **kwargs)
self.check_enhancement_factors()

Expand Down
50 changes: 50 additions & 0 deletions tests/batch_queues/test_bq_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DummyData,
DummySampler,
)
from sup3r.utilities.utilities import Timer

FEATURES = ['windspeed', 'winddirection']

Expand Down Expand Up @@ -53,6 +54,55 @@ def test_batch_queue():
batcher.stop()


def test_batch_queue_workers():
    """Check that using max_workers > 1 for a batch queue is faster than using
    max_workers = 1."""

    timer = Timer()
    sample_shape = (10, 10, 20)
    n_batches = 20
    batch_size = 10
    max_workers = 10
    n_epochs = 10
    chunk_shape = {'south_north': 20, 'west_east': 20, 'time': 40}
    samplers = [
        DummySampler(
            sample_shape,
            data_shape=(100, 100, 1000),
            batch_size=batch_size,
            features=FEATURES,
            chunk_shape=chunk_shape
        )
    ]

    def mean_batch_time(workers):
        # Build a queue with the given worker count, drain it for
        # ``n_epochs`` epochs, and return the average wall-clock time
        # spent producing a single batch.
        queue = SingleBatchQueue(
            samplers=samplers,
            n_batches=n_batches,
            batch_size=batch_size,
            max_workers=workers,
        )
        timer.start()
        for _ in range(n_epochs):
            _ = list(queue)
        timer.stop()
        queue.stop()
        return timer.elapsed / (n_epochs * n_batches)

    # NOTE(review): wall-clock comparison — assumes the parallel queue's
    # speedup dominates scheduling noise on the test machine.
    serial_time = mean_batch_time(1)
    parallel_time = mean_batch_time(max_workers)
    print(f'Parallel / Serial Time: {parallel_time} / {serial_time}')
    assert parallel_time < serial_time


def test_spatial_batch_queue():
"""Smoke test for spatial batch queue. A batch queue returns batches for
spatial models if the sample shapes have 1 for the time axis"""
Expand Down

0 comments on commit 3165ce0

Please sign in to comment.