Commit 516f72a
examples added to DataHandler doc string. Some instructions on sup3rwind model training added to examples README.
bnb32 committed Dec 23, 2024
1 parent 92aa284 commit 516f72a
Showing 9 changed files with 86 additions and 88 deletions.
42 changes: 39 additions & 3 deletions examples/sup3rwind/README.rst
@@ -2,7 +2,7 @@
Sup3rWind Examples
###################

Super-Resolution for Renewable Energy Resource Data with Wind from Reanalysis Data (Sup3rWind) is one application of the sup3r software. In this work, we train generative models to create high-resolution (2km 5-minute) wind data based on coarse (30km hourly) ERA5 data. The generative models and high-resolution output data is publicly available via the `Open Energy Data Initiative (OEDI) <https://data.openei.org/s3_viewer?bucket=nrel-pds-wtk&prefix=sup3rwind%2F>`__ and via HSDS at the bucket ``nrel-pds-hsds`` and path ``/nrel/wtk/sup3rwind``. This data covers recent historical time periods for an expanding selection of countries.
Super-Resolution for Renewable Energy Resource Data with Wind from Reanalysis Data (Sup3rWind) is one application of the sup3r software. In this work, we train generative models to create high-resolution (2km 5-minute) wind data based on coarse (30km hourly) ERA5 data. The generative models, high-resolution output data, and training data are publicly available via the `Open Energy Data Initiative (OEDI) <https://data.openei.org/s3_viewer?bucket=nrel-pds-wtk&prefix=sup3rwind%2F>`__ and via HSDS at the bucket ``nrel-pds-hsds`` and path ``/nrel/wtk/sup3rwind``. This data covers recent historical time periods for an expanding selection of countries.

Sup3rWind Data Access
----------------------
@@ -11,8 +11,8 @@ The Sup3rWind data and models are publicly available in a public AWS S3 bucket.

The Sup3rWind data is also loaded into `HSDS <https://www.hdfgroup.org/solutions/highly-scalable-data-service-hsds/>`__ so that you may stream the data via the `NREL developer API <https://developer.nrel.gov/signup/>`__ or your own HSDS server. This is the best option if you don't need a full annual dataset. See these `rex instructions <https://nrel.github.io/rex/misc/examples.hsds.html>`__ for more details on how to access this data with HSDS and rex.
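
If you just want to inspect a small slice before committing to a download, rex can stream Sup3rWind over HSDS directly. A minimal sketch, assuming a hypothetical file name under ``/nrel/wtk/sup3rwind`` and a WTK-style dataset name (browse the HSDS directory for the real paths):

.. code-block:: python

    from rex import WindX

    # Hypothetical file path -- list /nrel/wtk/sup3rwind for real names
    fp = '/nrel/wtk/sup3rwind/sup3rwind_ukraine_2020.h5'
    with WindX(fp, hsds=True) as res:
        # 5-minute 100m windspeed time series at the first site
        ws = res['windspeed_100m', :, 0]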

Example Sup3rWind Data Usage
-----------------------------
Sup3rWind Data Usage
---------------------

Sup3rWind data can be used in much the same way as `Sup3rCC <https://nrel.github.io/sup3r/examples/sup3rcc.html>`__ data, with the caveat that Sup3rWind includes only wind data and the ancillary variables needed for modeling wind energy generation. Refer to the Sup3rCC `example notebook <https://github.com/NREL/sup3r/tree/main/examples/sup3rcc/using_the_data.ipynb>`__ for usage patterns.

@@ -32,6 +32,42 @@ The process for running the Sup3rWind models is much the same as for `Sup3rCC <h
#. If you're running on a Slurm cluster, this will kick off a number of jobs that you can see with the ``squeue`` command. If you're running locally, your terminal should now be running the Sup3rWind models. The software will create a ``./logs/`` directory in which you can monitor the progress of your jobs.
#. The ``sup3r-pipeline`` is designed to run several modules in serial, with each module running multiple chunks in parallel. Once the first module (forward-pass) finishes, run ``python -m sup3r.cli -c config_pipeline.json pipeline`` again. This will clean up status files and kick off the next step in the pipeline (if the current step was successful). A minimal sketch of such a pipeline config follows this list.
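
A minimal pipeline config sketch, assuming the standard NREL pipeline schema (the module list and file names below are hypothetical; see the sup3r docs for the authoritative keys):

.. code-block:: python

    import json

    # Hypothetical module sequence -- adjust to your workflow
    config = {
        'logging': {'log_level': 'INFO'},
        'pipeline': [
            {'forward-pass': './config_fwp.json'},
            {'data-collect': './config_collect.json'},
        ],
    }
    with open('config_pipeline.json', 'w') as f:
        json.dump(config, f, indent=4)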

Training from scratch
---------------------

To train Sup3rWind models from scratch, use the public training `data <https://data.openei.org/s3_viewer?bucket=nrel-pds-wtk&prefix=sup3rwind%2Ftraining_data%2F>`__. This data is for training the spatial enhancement models only. The 2024-01 `models <https://data.openei.org/s3_viewer?bucket=nrel-pds-wtk&prefix=sup3rwind%2Fmodels%2Fsup3rwind_models_202401%2F>`__ perform spatial enhancement in two steps: 3x from ERA5 to coarsened WTK, and 5x from coarsened WTK to uncoarsened WTK. The currently used approach performs spatial enhancement in a single 15x step.

#. For a given year and training domain, initialize low-resolution and high-resolution data handlers and wrap them in a dual rasterizer object. Do this for as many years and training regions as desired, and use these containers to initialize a batch handler. To train models for 3x spatial enhancement, use ``hr_spatial_coarsen=5`` in the ``hr_dh``. To train models for 15x (the currently used approach), use ``hr_spatial_coarsen=1``. (Refer to the tests and docs for information on additional arguments, denoted by the ellipses below.)

   .. code-block:: python

      from sup3r.preprocessing import DataHandler, DualBatchHandler, DualRasterizer

      containers = []
      for tdir in training_dirs:
          lr_dh = DataHandler(f"{tdir}/lr_*.h5", ...)
          hr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=...)
          container = DualRasterizer({'low_res': lr_dh, 'high_res': hr_dh}, ...)
          containers.append(container)
      bh = DualBatchHandler(train_containers=containers, ...)

#. To train a 5x model, use the ``hr_*.h5`` files for both the ``lr_dh`` and the ``hr_dh``, with ``hr_spatial_coarsen=3`` in the ``lr_dh`` and ``hr_spatial_coarsen=1`` in the ``hr_dh``.

   .. code-block:: python

      containers = []
      for tdir in training_dirs:
          lr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=3, ...)
          hr_dh = DataHandler(f"{tdir}/hr_*.h5", hr_spatial_coarsen=1, ...)
          container = DualRasterizer({'low_res': lr_dh, 'high_res': hr_dh}, ...)
          containers.append(container)
      bh = DualBatchHandler(train_containers=containers, ...)

#. Initialize a 3x, 5x, or 15x spatial enhancement model with 14 output channels and train for the desired number of epochs. (The 3x and 5x generator configs can be copied from the ``model_params.json`` files in each OEDI model `directory <https://data.openei.org/s3_viewer?bucket=nrel-pds-wtk&prefix=sup3rwind%2Fmodels%2Fsup3rwind_models_202401%2F>`__. The 15x generator config can be created from the OEDI model configs by changing the spatial enhancement factor, or from the configs in the repo by changing the enhancement factor and the number of output channels.)

   .. code-block:: python

      from sup3r.models import Sup3rGan

      model = Sup3rGan(gen_layers="./gen_config.json",
                       disc_layers="./disc_config.json", ...)
      model.train(bh, ...)

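Once training converges, save the model for use with the forward-pass pipeline. A minimal sketch, assuming ``Sup3rGan.save`` writes the trained weights and a ``model_params.json`` to the given directory, mirroring the OEDI model directory layout:

.. code-block:: python

   # Assumption: save() produces a directory loadable by the
   # forward-pass step, like the OEDI model directories.
   model.save('./sup3rwind_15x_model')
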
Sup3rWind Versions
-------------------

9 changes: 5 additions & 4 deletions sup3r/preprocessing/base.py
@@ -116,6 +116,8 @@ class Sup3rDataset:
from the high_res non coarsened variable.
"""

DSET_NAMES = ('low_res', 'high_res', 'obs')

def __init__(
self,
**dsets: Mapping[str, Union[xr.Dataset, Sup3rX]],
@@ -185,7 +187,7 @@ def rewrap(self, data):
return data
if len(data) == 1:
return type(self)(high_res=data[0])
return type(self)(**dict(zip(['low_res', 'high_res', 'obs'], data)))
return type(self)(**dict(zip(self.DSET_NAMES, data)))

def sample(self, idx):
"""Get samples from ``self._ds`` members. idx should be either a tuple
@@ -369,16 +371,15 @@ def wrap(self, data):
if isinstance(data, dict):
data = Sup3rDataset(**data)

default_names = ['low_res', 'high_res', 'obs']
if isinstance(data, tuple) and len(data) > 1:
msg = (
f'{self.__class__.__name__}.data is being set with a '
f'{len(data)}-tuple without explicit dataset names. We will '
f'assume name ordering: {default_names[:len(data)]}'
f'assume name ordering: {Sup3rDataset.DSET_NAMES[:len(data)]}'
)
logger.warning(msg)
warn(msg)
data = Sup3rDataset(**dict(zip(default_names, data)))
data = Sup3rDataset(**dict(zip(Sup3rDataset.DSET_NAMES, data)))
elif not isinstance(data, Sup3rDataset):
name = getattr(data, 'name', None) or 'high_res'
data = Sup3rDataset(**{name: data})
9 changes: 2 additions & 7 deletions sup3r/preprocessing/batch_queues/abstract.py
@@ -22,11 +22,7 @@
from sup3r.utilities.utilities import RANDOM_GENERATOR, Timer

if TYPE_CHECKING:
from sup3r.preprocessing.samplers import (
DualSampler,
DualSamplerWithObs,
Sampler,
)
from sup3r.preprocessing.samplers import DualSampler, Sampler

logger = logging.getLogger(__name__)

@@ -41,8 +37,7 @@ class AbstractBatchQueue(Collection, ABC):
def __init__(
self,
samplers: Union[
List['Sampler'], List['DualSampler'], List['DualSamplerWithObs']
],
List['Sampler'], List['DualSampler']],
batch_size: int = 16,
n_batches: int = 64,
s_enhance: int = 1,
2 changes: 1 addition & 1 deletion sup3r/preprocessing/cachers/base.py
@@ -330,7 +330,7 @@ def write_h5(
]

if Dimension.TIME in data:
data[Dimension.TIME] = data[Dimension.TIME].astype(int)
data[Dimension.TIME] = data[Dimension.TIME].astype('int64')

for dset in [*coord_names, *features]:
data_var, chunksizes = cls.get_chunksizes(dset, data, chunks)
25 changes: 25 additions & 0 deletions sup3r/preprocessing/data_handlers/factory.py
@@ -138,6 +138,31 @@ def __init__(
Dictionary of additional keyword args for
:class:`~sup3r.preprocessing.rasterizers.Rasterizer`, used
specifically for rasterizing flattened data

Examples
--------
Extract windspeed at 40m and 80m above the ground from files containing
u/v at 10m and 100m. Windspeed will be interpolated from the surrounding
levels using a log profile. ``dh`` will contain dask arrays of this data
with 10x10x50 chunk sizes. Data will be cached to files named
'windspeed_40m.h5' and 'windspeed_80m.h5' in './cache_dir' with 5x5x10
chunks on disk.

>>> cache_chunks = {'south_north': 5, 'west_east': 5, 'time': 10}
>>> load_chunks = {'south_north': 10, 'west_east': 10, 'time': 50}
>>> grid_size = (50, 50)
>>> lower_left_coordinate = (39.7, -105.2)
>>> dh = DataHandler(
...     file_paths=['./data_dir/u_10m.nc', './data_dir/u_100m.nc',
...                 './data_dir/v_10m.nc', './data_dir/v_100m.nc'],
...     features=['windspeed_40m', 'windspeed_80m'],
...     shape=grid_size, time_slice=slice(0, 100),
...     target=lower_left_coordinate, hr_spatial_coarsen=2,
...     chunks=load_chunks, interp_kwargs={'method': 'log'},
...     cache_kwargs={'cache_pattern': './cache_dir/{feature}.h5',
...                   'chunks': cache_chunks})

Derive more features from an already initialized data handler:

>>> dh['windspeed_60m'] = dh.derive('windspeed_60m')
"""  # pylint: disable=line-too-long

features = parse_to_list(features=features)
11 changes: 6 additions & 5 deletions sup3r/preprocessing/rasterizers/dual.py
@@ -2,7 +2,7 @@
datasets"""

import logging
from typing import Tuple, Union
from typing import Dict, Tuple, Union
from warnings import warn

import numpy as np
@@ -38,7 +38,7 @@ class DualRasterizer(Container):
@log_args
def __init__(
self,
data: Union[Sup3rDataset, Tuple[xr.Dataset, xr.Dataset]],
data: Union[Sup3rDataset, Tuple[xr.Dataset, xr.Dataset], Dict[str, xr.Dataset]],

GitHub Actions / Ruff: sup3r/preprocessing/rasterizers/dual.py:41:80: E501 Line too long (88 > 79)
regrid_workers=1,
regrid_lr=True,
s_enhance=1,
@@ -51,7 +51,8 @@ def __init__(
Parameters
----------
data : Sup3rDataset | Tuple[xr.Dataset, xr.Dataset]
data : Sup3rDataset | Tuple[xr.Dataset, xr.Dataset] |
Dict[str, xr.Dataset]
A tuple of xr.Dataset instances. The first must be low-res
and the second must be high-res data
regrid_workers : int | None
@@ -78,9 +79,9 @@
if isinstance(data, tuple):
data = {'low_res': data[0], 'high_res': data[1]}
if isinstance(data, dict):
data = Sup3rDataset(data)
data = Sup3rDataset(**data)
msg = (
'The DualRasterizer requires either a data tuple with two '
'The DualRasterizer requires a data tuple or dictionary with two '
'members, low and high resolution in that order, or a '
f'Sup3rDataset instance. Received {type(data)}.'
)
15 changes: 7 additions & 8 deletions sup3r/preprocessing/samplers/dual.py
@@ -63,10 +63,11 @@ def __init__(
'with `.low_res` and `.high_res` data members, and optionally an '
'`.obs` member, in that order'
)
check = hasattr(data, 'low_res') and hasattr(data, 'high_res')
check = check and data.low_res == data[0] and data.high_res == data[1]
if len(data) == 2:
check = check and (hasattr(data, 'obs') and data.obs == data[2])
dnames = ['low_res', 'high_res', 'obs'][:len(data)]
check = all(
hasattr(data, dname) and getattr(data, dname) == data[i]
for i, dname in enumerate(dnames)
)
assert check, msg

super().__init__(
@@ -146,7 +147,5 @@ def get_sample_index(self, n_obs=None):
]
hr_index = (*hr_index, self.hr_features)

sample_index = (lr_index, hr_index)
if hasattr(self.data, 'obs'):
sample_index += (hr_index,)
return sample_index
sample_index = (lr_index, hr_index, hr_index)
return sample_index[:len(self.data)]
2 changes: 1 addition & 1 deletion sup3r/utilities/utilities.py
@@ -26,7 +26,7 @@ def preprocess_datasets(dset):
dset.indexes['time'], 'to_datetimeindex'
):
dset['time'] = dset.indexes['time'].to_datetimeindex()
ti = dset['time'].astype(int)
ti = dset['time'].astype('int64')
dset['time'] = ti
if 'latitude' in dset.dims:
dset = dset.swap_dims({'latitude': 'south_north'})
59 changes: 0 additions & 59 deletions tests/data/extract_raster_wtk.py

This file was deleted.
