From b8344adb495865d124ac172ec988a260b5f30abd Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 26 Dec 2024 14:22:32 -0800
Subject: [PATCH] removed namedtuple from Sup3rDataset to make Sup3rDataset
 picklable.

---
 sup3r/models/abstract.py                     | 14 +++---
 sup3r/preprocessing/base.py                  | 25 ++++++++++-
 sup3r/preprocessing/batch_queues/abstract.py | 47 ++++++++++++--------
 sup3r/preprocessing/cachers/base.py          |  1 +
 sup3r/utilities/pytest/helpers.py            | 27 +++++------
 5 files changed, 74 insertions(+), 40 deletions(-)

diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index 17836deb2..506f85b38 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -684,9 +684,9 @@ def finish_epoch(
         """
         self.log_loss_details(loss_details)
         self._history.at[epoch, 'elapsed_time'] = time.time() - t0
-        for key, value in loss_details.items():
-            if key != 'n_obs':
-                self._history.at[epoch, key] = value
+        cols = [k for k in loss_details if k != 'n_obs']
+        entry = np.vstack([loss_details[k] for k in cols])
+        self._history.loc[epoch, cols] = entry.T

         last_epoch = epoch == epochs[-1]
         chp = checkpoint_int is not None and (epoch % checkpoint_int) == 0
@@ -710,8 +710,8 @@ def finish_epoch(
             self.save(out_dir.format(epoch=epoch))

         if extras is not None:
-            for k, v in extras.items():
-                self._history.at[epoch, k] = safe_cast(v)
+            entry = np.vstack([safe_cast(v) for v in extras.values()])
+            self._history.loc[epoch, list(extras)] = entry.T

         return stop

@@ -744,6 +744,8 @@ def run_gradient_descent(
             current loss weight values.
         obs_data : tf.Tensor | None
             Optional observation data to use in additional content loss term.
+            (n_observations, spatial_1, spatial_2, features)
+            (n_observations, spatial_1, spatial_2, temporal, features)
         optimizer : tf.keras.optimizers.Optimizer
             Optimizer class to use to update weights. This can be different if
             you're training just the generator or one of the discriminator
@@ -1054,6 +1056,8 @@ def get_single_grad(
             current loss weight values.
         obs_data : tf.Tensor | None
             Optional observation data to use in additional content loss term.
+            (n_observations, spatial_1, spatial_2, features)
+            (n_observations, spatial_1, spatial_2, temporal, features)
         device_name : None | str
             Optional tensorflow device name for GPU placement. Note that if a
             GPU is available, variables will be placed on that GPU even if
diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py
index a6e9074e9..8a82efe3d 100644
--- a/sup3r/preprocessing/base.py
+++ b/sup3r/preprocessing/base.py
@@ -11,7 +11,6 @@
 import logging
 import pprint
 from abc import ABCMeta
-from collections import namedtuple
 from typing import Dict, Mapping, Tuple, Union
 from warnings import warn

@@ -70,6 +69,28 @@ def __repr__(cls):
         return f"<class '{cls.__module__}.{cls.__name__}'>"


+class DsetTuple:
+    """A simple class to mimic namedtuple behavior with dynamic attributes
+    while being serializable"""
+
+    def __init__(self, **kwargs):
+        self.__dict__.update(kwargs)
+
+    def __iter__(self):
+        return iter(self.__dict__.values())
+
+    def __getitem__(self, key):
+        if isinstance(key, int):
+            key = list(self.__dict__)[key]
+        return self.__dict__[key]
+
+    def __len__(self):
+        return len(self.__dict__)
+
+    def __repr__(self):
+        return f"DsetTuple({self.__dict__})"
+
+
 class Sup3rDataset:
     """Interface for interacting with one or two ``xr.Dataset`` instances.
     This is a wrapper around one or two ``Sup3rX`` objects so they work well
@@ -149,7 +170,7 @@ def __init__(
                 assert len(dset) == 1, msg
                 dsets[name] = dset._ds[0]

-        self._ds = namedtuple('Dataset', list(dsets))(**dsets)
+        self._ds = DsetTuple(**dsets)

     def __iter__(self):
         yield from self._ds
diff --git a/sup3r/preprocessing/batch_queues/abstract.py b/sup3r/preprocessing/batch_queues/abstract.py
index a9423fe7e..b3083e11e 100644
--- a/sup3r/preprocessing/batch_queues/abstract.py
+++ b/sup3r/preprocessing/batch_queues/abstract.py
@@ -12,9 +12,9 @@
 import time
 from abc import ABC, abstractmethod
 from collections import namedtuple
-from concurrent.futures import ThreadPoolExecutor
 from typing import TYPE_CHECKING, List, Optional, Union

+import dask
 import numpy as np
 import tensorflow as tf

@@ -36,8 +36,7 @@ class AbstractBatchQueue(Collection, ABC):

     def __init__(
         self,
-        samplers: Union[
-            List['Sampler'], List['DualSampler']],
+        samplers: Union[List['Sampler'], List['DualSampler']],
         batch_size: int = 16,
         n_batches: int = 64,
         s_enhance: int = 1,
@@ -237,6 +236,24 @@ def running(self):
             and not self.queue.is_closed()
         )

+    def sample_batches(self, n_batches) -> None:
+        """Sample N batches from samplers. Returns N batches which are then
+        used to fill the queue."""
+        if n_batches == 1:
+            return [self.sample_batch()]
+
+        tasks = [dask.delayed(self.sample_batch)() for _ in range(n_batches)]
+        logger.debug('Added %s sample_batch futures to %s queue.',
+                     n_batches,
+                     self._thread_name)
+
+        if self.max_workers == 1:
+            batches = dask.compute(*tasks, scheduler='single-threaded')
+        else:
+            batches = dask.compute(
+                *tasks, scheduler='threads', num_workers=self.max_workers)
+        return batches
+
     def enqueue_batches(self) -> None:
         """Callback function for queue thread. While training, the queue is
         checked for empty spots and filled. In the training thread, batches are
@@ -244,16 +261,15 @@ def enqueue_batches(self) -> None:
         log_time = time.time()
         while self.running:
             needed = self.queue_cap - self.queue.size().numpy()
-            if needed == 1 or self.max_workers == 1:
-                self.enqueue_batch()
-            elif needed > 0:
-                with ThreadPoolExecutor(self.max_workers) as exe:
-                    _ = [exe.submit(self.enqueue_batch) for _ in range(needed)]
-                logger.debug(
-                    'Added %s enqueue futures to %s queue.',
-                    needed,
-                    self._thread_name,
-                )
+
+            # no point in getting more than one batch at a time if
+            # max_workers == 1
+            needed = 1 if needed > 0 and self.max_workers == 1 else needed
+
+            if needed > 0:
+                for batch in self.sample_batches(n_batches=needed):
+                    self.queue.enqueue(batch)
+
             if time.time() > log_time + 10:
                 logger.debug(self.log_queue_info())
                 log_time = time.time()
@@ -317,11 +333,6 @@ def log_queue_info(self):
             self.queue_cap,
         )

-    def enqueue_batch(self):
-        """Build batch and send to queue."""
-        if self.running and self.queue.size().numpy() < self.queue_cap:
-            self.queue.enqueue(self.sample_batch())
-
     @property
     def lr_shape(self):
         """Shape of low resolution sample in a low-res / high-res pair. (e.g.
diff --git a/sup3r/preprocessing/cachers/base.py b/sup3r/preprocessing/cachers/base.py
index 4df526039..82a4bd5c1 100644
--- a/sup3r/preprocessing/cachers/base.py
+++ b/sup3r/preprocessing/cachers/base.py
@@ -330,6 +330,7 @@ def write_h5(
         ]

         if Dimension.TIME in data:
+            # int64 used explicitly to avoid incorrect encoding as int32
             data[Dimension.TIME] = data[Dimension.TIME].astype('int64')

         for dset in [*coord_names, *features]:
diff --git a/sup3r/utilities/pytest/helpers.py b/sup3r/utilities/pytest/helpers.py
index 3f9b320a5..bb9912c9f 100644
--- a/sup3r/utilities/pytest/helpers.py
+++ b/sup3r/utilities/pytest/helpers.py
@@ -107,9 +107,17 @@ class DummySampler(Sampler):
     """Dummy container with random data."""

     def __init__(
-        self, sample_shape, data_shape, features, batch_size, feature_sets=None
+        self,
+        sample_shape,
+        data_shape,
+        features,
+        batch_size,
+        feature_sets=None,
+        chunk_shape=None,
     ):
         data = make_fake_dset(data_shape, features=features)
+        if chunk_shape is not None:
+            data = data.chunk(chunk_shape)
         super().__init__(
             Sup3rDataset(high_res=data),
             sample_shape,
@@ -314,10 +322,7 @@ def make_collect_chunks(td):
     out_files = []
     for t, slice_hr in enumerate(t_slices_hr):
         for s, (s1_hr, s2_hr) in enumerate(product(s_slices_hr, s_slices_hr)):
-            out_file = out_pattern.format(
-                t=str(t).zfill(6),
-                s=str(s).zfill(6)
-            )
+            out_file = out_pattern.format(t=str(t).zfill(6), s=str(s).zfill(6))
             out_files.append(out_file)
             OutputHandlerH5._write_output(
                 data[s1_hr, s2_hr, slice_hr, :],
@@ -330,15 +335,7 @@ def make_collect_chunks(td):
                 gids=gids[s1_hr, s2_hr],
             )

-    return (
-        out_files,
-        data,
-        ws_true,
-        wd_true,
-        features,
-        hr_lat_lon,
-        hr_times
-    )
+    return (out_files, data, ws_true, wd_true, features, hr_lat_lon, hr_times)


 def make_fake_h5_chunks(td):
@@ -436,7 +433,7 @@ def make_fake_h5_chunks(td):
         s_slices_lr,
         s_slices_hr,
         low_res_lat_lon,
-        low_res_times
+        low_res_times,
     )

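For context on the motivation in the subject line, here is a minimal sketch (not part of the patch) of why a namedtuple class created on the fly inside Sup3rDataset.__init__ cannot be pickled, while the module-level DsetTuple can. It assumes the patched sup3r package is importable; the 'high_res' key mirrors the usage in Sup3rDataset.__init__, and the placeholder value is purely illustrative.

import pickle
from collections import namedtuple

from sup3r.preprocessing.base import DsetTuple  # assumes the patched package

dsets = {'high_res': 'placeholder dataset'}

# Old approach: a namedtuple class created dynamically. pickle serializes
# instances by importable class name, and this class only exists locally,
# so dumping an instance raises PicklingError.
dynamic = namedtuple('Dataset', list(dsets))(**dsets)
try:
    pickle.dumps(dynamic)
except pickle.PicklingError as err:
    print('namedtuple instance failed to pickle:', err)

# New approach: DsetTuple is defined at module level, so instances
# round-trip through pickle while keeping tuple-like access.
ds = DsetTuple(**dsets)
restored = pickle.loads(pickle.dumps(ds))
print(restored.high_res, restored['high_res'], restored[0], len(restored))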
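The queue change replaces the ThreadPoolExecutor fan-out with dask.delayed tasks computed on a threaded scheduler. Below is a standalone sketch of that pattern; sample_batch here is a hypothetical stand-in for the sampler call, not the sup3r method itself.

import dask


def sample_batch():
    """Hypothetical stand-in for AbstractBatchQueue.sample_batch."""
    return sum(range(1_000))


def sample_batches(n_batches, max_workers=4):
    """Mirror the pattern added in batch_queues/abstract.py: build delayed
    tasks, then compute them all at once with an appropriate scheduler."""
    if n_batches == 1:
        return [sample_batch()]
    tasks = [dask.delayed(sample_batch)() for _ in range(n_batches)]
    if max_workers == 1:
        # single-threaded scheduler avoids thread overhead for one worker
        return dask.compute(*tasks, scheduler='single-threaded')
    return dask.compute(*tasks, scheduler='threads', num_workers=max_workers)


print(len(sample_batches(8)))  # 8 batches computed across worker threads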