Commit
Removed namedtuple from Sup3rDataset so that Sup3rDataset instances are picklable.
bnb32 committed Dec 26, 2024
1 parent 968f9b6 commit b8344ad
Showing 5 changed files with 74 additions and 40 deletions.
14 changes: 9 additions & 5 deletions sup3r/models/abstract.py
@@ -684,9 +684,9 @@ def finish_epoch(
"""
self.log_loss_details(loss_details)
self._history.at[epoch, 'elapsed_time'] = time.time() - t0
for key, value in loss_details.items():
if key != 'n_obs':
self._history.at[epoch, key] = value
cols = [k for k in loss_details if k != 'n_obs']
entry = np.vstack([loss_details[k] for k in cols])
self._history.loc[epoch, cols] = entry.T

last_epoch = epoch == epochs[-1]
chp = checkpoint_int is not None and (epoch % checkpoint_int) == 0
@@ -710,8 +710,8 @@ def finish_epoch(
self.save(out_dir.format(epoch=epoch))

if extras is not None:
for k, v in extras.items():
self._history.at[epoch, k] = safe_cast(v)
entry = np.vstack([safe_cast(v) for v in extras.values()])
self._history.loc[epoch, list(extras)] = entry.T

return stop
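
Note: the per-key ``.at`` writes above are replaced with a single vectorized ``.loc`` assignment across all loss columns. Below is a minimal, self-contained sketch of that pattern; the DataFrame, column names, and values are hypothetical stand-ins for ``self._history`` and ``loss_details``, not taken from the model code.

```python
import pandas as pd

# Hypothetical stand-ins for self._history and loss_details; values are made up.
history = pd.DataFrame(columns=['loss_gen', 'loss_disc'], dtype=float)
loss_details = {'loss_gen': 0.12, 'loss_disc': 0.34, 'n_obs': 64}
epoch = 0

# Write every loss column for this epoch with one .loc assignment instead of
# a separate .at call per key; 'n_obs' is bookkeeping and is not logged.
cols = [k for k in loss_details if k != 'n_obs']
history.loc[epoch, cols] = [float(loss_details[k]) for k in cols]
```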

@@ -744,6 +744,8 @@ def run_gradient_descent(
current loss weight values.
obs_data : tf.Tensor | None
Optional observation data to use in additional content loss term.
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
optimizer : tf.keras.optimizers.Optimizer
Optimizer class to use to update weights. This can be different if
you're training just the generator or one of the discriminator
@@ -1054,6 +1056,8 @@ def get_single_grad(
current loss weight values.
obs_data : tf.Tensor | None
Optional observation data to use in additional content loss term.
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
device_name : None | str
Optional tensorflow device name for GPU placement. Note that if a
GPU is available, variables will be placed on that GPU even if
25 changes: 23 additions & 2 deletions sup3r/preprocessing/base.py
@@ -11,7 +11,6 @@
import logging
import pprint
from abc import ABCMeta
from collections import namedtuple
from typing import Dict, Mapping, Tuple, Union
from warnings import warn

@@ -70,6 +69,28 @@ def __repr__(cls):
return f"<class '{cls.__module__}.{cls.__name__}'>"


class DsetTuple:
"""A simple class to mimic namedtuple behavior with dynamic attributes
while being serializable"""

def __init__(self, **kwargs):
self.__dict__.update(kwargs)

def __iter__(self):
return iter(self.__dict__.values())

def __getitem__(self, key):
if isinstance(key, int):
key = list(self.__dict__)[key]
return self.__dict__[key]

def __len__(self):
return len(self.__dict__)

def __repr__(self):
return f"DsetTuple({self.__dict__})"


class Sup3rDataset:
"""Interface for interacting with one or two ``xr.Dataset`` instances.
This is a wrapper around one or two ``Sup3rX`` objects so they work well
@@ -149,7 +170,7 @@ def __init__(
assert len(dset) == 1, msg
dsets[name] = dset._ds[0]

self._ds = namedtuple('Dataset', list(dsets))(**dsets)
self._ds = DsetTuple(**dsets)

def __iter__(self):
yield from self._ds
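
The point of swapping the dynamically built namedtuple for the module-level DsetTuple class is that instances can now survive a pickle round trip (e.g. when containers are shipped to worker processes). A quick sketch of the intended behavior, using plain strings as placeholders for the wrapped Sup3rX datasets:

```python
import pickle

from sup3r.preprocessing.base import DsetTuple

# Placeholder values; real usage stores Sup3rX dataset objects.
dset = DsetTuple(low_res='lr_data', high_res='hr_data')

assert dset.high_res == 'hr_data'           # attribute access, like a namedtuple
assert dset['high_res'] == 'hr_data'        # string keys also work
assert dset[0] == 'lr_data'                 # integer index follows insertion order
assert list(dset) == ['lr_data', 'hr_data']

# A namedtuple created at runtime with namedtuple('Dataset', ...) cannot be
# unpickled because its class is not importable by name; DsetTuple can.
restored = pickle.loads(pickle.dumps(dset))
assert restored.low_res == 'lr_data'
```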
47 changes: 29 additions & 18 deletions sup3r/preprocessing/batch_queues/abstract.py
@@ -12,9 +12,9 @@
import time
from abc import ABC, abstractmethod
from collections import namedtuple
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, List, Optional, Union

import dask
import numpy as np
import tensorflow as tf

@@ -36,8 +36,7 @@ class AbstractBatchQueue(Collection, ABC):

def __init__(
self,
samplers: Union[
List['Sampler'], List['DualSampler']],
samplers: Union[List['Sampler'], List['DualSampler']],
batch_size: int = 16,
n_batches: int = 64,
s_enhance: int = 1,
@@ -237,23 +236,40 @@ def running(self):
and not self.queue.is_closed()
)

def sample_batches(self, n_batches):
"""Sample ``n_batches`` batches from the samplers. The returned batches
are then used to fill the queue."""
if n_batches == 1:
return [self.sample_batch()]

tasks = [dask.delayed(self.sample_batch)() for _ in range(n_batches)]
logger.debug('Added %s sample_batch futures to %s queue.',
n_batches,
self._thread_name)

if self.max_workers == 1:
batches = dask.compute(*tasks, scheduler='single-threaded')
else:
batches = dask.compute(
*tasks, scheduler='threads', num_workers=self.max_workers)
return batches

def enqueue_batches(self) -> None:
"""Callback function for queue thread. While training, the queue is
checked for empty spots and filled. In the training thread, batches are
removed from the queue."""
log_time = time.time()
while self.running:
needed = self.queue_cap - self.queue.size().numpy()
if needed == 1 or self.max_workers == 1:
self.enqueue_batch()
elif needed > 0:
with ThreadPoolExecutor(self.max_workers) as exe:
_ = [exe.submit(self.enqueue_batch) for _ in range(needed)]
logger.debug(
'Added %s enqueue futures to %s queue.',
needed,
self._thread_name,
)

# no point in getting more than one batch at a time if
# max_workers == 1
needed = 1 if needed > 0 and self.max_workers == 1 else needed

if needed > 0:
for batch in self.sample_batches(n_batches=needed):
self.queue.enqueue(batch)

if time.time() > log_time + 10:
logger.debug(self.log_queue_info())
log_time = time.time()
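
The queue thread now batches its refill work through dask instead of spawning a ThreadPoolExecutor: sample_batch calls are wrapped in dask.delayed and computed with either the single-threaded or the threaded scheduler. A stripped-down sketch of that pattern follows, with a stand-in sample_batch (the real method takes no arguments and assembles tensors from the samplers):

```python
import dask

def sample_batch(i):
    # stand-in for AbstractBatchQueue.sample_batch
    return f'batch_{i}'

n_batches, max_workers = 4, 2
tasks = [dask.delayed(sample_batch)(i) for i in range(n_batches)]

if max_workers == 1:
    # no thread pool overhead when only one worker is allowed
    batches = dask.compute(*tasks, scheduler='single-threaded')
else:
    batches = dask.compute(*tasks, scheduler='threads', num_workers=max_workers)

print(batches)  # ('batch_0', 'batch_1', 'batch_2', 'batch_3')
```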
@@ -317,11 +333,6 @@ def log_queue_info(self):
self.queue_cap,
)

def enqueue_batch(self):
"""Build batch and send to queue."""
if self.running and self.queue.size().numpy() < self.queue_cap:
self.queue.enqueue(self.sample_batch())

@property
def lr_shape(self):
"""Shape of low resolution sample in a low-res / high-res pair. (e.g.
1 change: 1 addition & 0 deletions sup3r/preprocessing/cachers/base.py
@@ -330,6 +330,7 @@ def write_h5(
]

if Dimension.TIME in data:
# int64 used explicitly to avoid incorrect encoding as int32
data[Dimension.TIME] = data[Dimension.TIME].astype('int64')

for dset in [*coord_names, *features]:
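
For context on the new comment above: a nanosecond-resolution datetime64 timestamp is on the order of 1.7e18, far outside the int32 range, so letting the time index be encoded as 32-bit integers silently corrupts the values (this is the assumed rationale for the explicit cast). A small illustration:

```python
import numpy as np
import pandas as pd

times = pd.date_range('2024-01-01', periods=3, freq='h')

as_int64 = times.values.astype('int64')   # exact nanosecond epoch timestamps
as_int32 = as_int64.astype('int32')       # silently wraps around / overflows

print(np.iinfo('int32').max)  # 2147483647
print(as_int64[0])            # 1704067200000000000
print(as_int32[0])            # meaningless wrapped value
```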
27 changes: 12 additions & 15 deletions sup3r/utilities/pytest/helpers.py
@@ -107,9 +107,17 @@ class DummySampler(Sampler):
"""Dummy container with random data."""

def __init__(
self, sample_shape, data_shape, features, batch_size, feature_sets=None
self,
sample_shape,
data_shape,
features,
batch_size,
feature_sets=None,
chunk_shape=None,
):
data = make_fake_dset(data_shape, features=features)
if chunk_shape is not None:
data = data.chunk(chunk_shape)
super().__init__(
Sup3rDataset(high_res=data),
sample_shape,
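
The new optional chunk_shape argument lets tests exercise dask-chunked data, since data.chunk(chunk_shape) is applied before the sampler is built. A hypothetical call is sketched below; the shapes, feature names, and dimension names are made up, and the dict keys must match the dimensions used by make_fake_dset.

```python
from sup3r.utilities.pytest.helpers import DummySampler

# Illustrative values only; adjust dimension names to match make_fake_dset.
sampler = DummySampler(
    sample_shape=(8, 8, 4),
    data_shape=(20, 20, 100),
    features=['u_10m', 'v_10m'],
    batch_size=4,
    chunk_shape={'south_north': 10, 'west_east': 10, 'time': 50},
)
```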
@@ -314,10 +322,7 @@ def make_collect_chunks(td):
out_files = []
for t, slice_hr in enumerate(t_slices_hr):
for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)):
out_file = out_pattern.format(
t=str(t).zfill(6),
s=str(s).zfill(6)
)
out_file = out_pattern.format(t=str(t).zfill(6), s=str(s).zfill(6))
out_files.append(out_file)
OutputHandlerH5._write_output(
data[s1_hr, s2_hr, slice_hr, :],
@@ -330,15 +335,7 @@ def make_fake_h5_chunks(td):
gids=gids[s1_hr, s2_hr],
)

return (
out_files,
data,
ws_true,
wd_true,
features,
hr_lat_lon,
hr_times
)
return (out_files, data, ws_true, wd_true, features, hr_lat_lon, hr_times)


def make_fake_h5_chunks(td):
@@ -436,7 +433,7 @@ def make_fake_h5_chunks(td):
s_slices_lr,
s_slices_hr,
low_res_lat_lon,
low_res_times
low_res_times,
)


