Enhancement to ERA5 Data Retrieval and Download Process #397

Open: wants to merge 6 commits into base: master
186 changes: 175 additions & 11 deletions atlite/datasets/era5.py
@@ -8,15 +8,18 @@
https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation
"""

import hashlib
import logging
import os
import time
import warnings
import weakref
from tempfile import mkstemp

import cdsapi
import numpy as np
import pandas as pd
import requests
import xarray as xr
from dask import compute, delayed
from dask.array import arctan2, sqrt
@@ -25,6 +28,10 @@
from atlite.gis import maybe_swap_spatial_dims
from atlite.pv.solar_position import SolarPosition

download_status = {}
file_aliases = {}
MAX_DISPLAY_FILES = 3

# Null context for running with statements without any context manager
try:
from contextlib import nullcontext
@@ -96,6 +103,26 @@ def _rename_and_clean_coords(ds, add_lon_lat=True):
ds = maybe_swap_spatial_dims(ds)
if add_lon_lat:
ds = ds.assign_coords(lon=ds.coords["x"], lat=ds.coords["y"])

# Combine ERA5 and ERA5T data into a single dimension.
# See https://github.com/PyPSA/atlite/issues/190
if "expver" in ds.coords:
unique_expver = np.unique(ds["expver"].values)
if len(unique_expver) > 1:
expver_dim = xr.DataArray(
unique_expver, dims=["expver"], coords={"expver": unique_expver}
)
ds = (
ds.assign_coords({"expver_dim": expver_dim})
.drop_vars("expver")
.rename({"expver_dim": "expver"})
.set_index(expver="expver")
)
for var in ds.data_vars:
ds[var] = ds[var].expand_dims("expver")
            # expver="0001" is ERA5 data, expver="0005" is ERA5T data. This
            # combines both by filling NaNs in the ERA5 data with ERA5T values.
ds = ds.sel(expver="0001").combine_first(ds.sel(expver="0005"))
ds = ds.drop_vars(["expver", "number"], errors="ignore")

return ds
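
For reference, a minimal sketch of what the combine_first merge above does, using made-up values (the era5/era5t names are illustrative, not variables from this PR):

import numpy as np
import xarray as xr

# Hypothetical case: ERA5 is still missing the newest timestep, ERA5T has it.
era5 = xr.DataArray([280.0, np.nan], dims="time", name="t2m")
era5t = xr.DataArray([np.nan, 281.5], dims="time", name="t2m")

merged = era5.combine_first(era5t)
# -> [280.0, 281.5]: ERA5 values win where present, ERA5T fills the gaps
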
@@ -323,6 +350,125 @@ def noisy_unlink(path):
logger.error(f"Unable to delete file {path}, as it is still in use.")


def get_cache_filename(request, cache_dir):
"""
Generate a unique cache filename based on the request parameters.
"""
# Serialize the request dictionary into a sorted string to ensure consistency
request_str = "_".join(
f"{key}-{sorted(value) if isinstance(value, list) else value}"
for key, value in sorted(request.items())
)
# Generate a hash of the request string
request_hash = hashlib.md5(request_str.encode("utf-8")).hexdigest()
    # Use the full hex digest as the filename so distinct requests get distinct files
    return f"{request_hash}.nc"
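
As a quick illustration (request values made up): key order and list order are normalised away, so equivalent requests map to the same cache file:

req_a = {"variable": ["2m_temperature", "runoff"], "year": 2020, "month": 1}
req_b = {"month": 1, "year": 2020, "variable": ["runoff", "2m_temperature"]}

assert get_cache_filename(req_a, ".") == get_cache_filename(req_b, ".")
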


def custom_download(url, size, target, lock, filename):
"""
Optimized download function that uses a simple progress bar and removes
completed files from the display.
"""
    if target is None:
        target = url.split("/")[-1]

    # Fall back to a no-op context manager so the locked progress updates
    # below also work when no lock is supplied.
    if lock is None:
        lock = nullcontext()

# Assign a short alias for the filename (e.g. f1, f2, ...)
file_number = len(file_aliases) + 1
file_aliases[filename] = f"f{file_number}"

    logger.info(f"Downloading {filename} to {target} ({size} bytes)")
start = time.time()

mode = "wb"
total = 0
sleep = 10
tries = 0
headers = None

while tries < 5:
r = requests.get(url, stream=True, headers=headers)
try:
r.raise_for_status()

with open(target, mode) as f:
for chunk in r.iter_content(chunk_size=1024 * 1024):
if chunk:
f.write(chunk)
total += len(chunk)
with lock:
download_status[filename] = total / size * 100
update_progress_bar()

        except requests.exceptions.ConnectionError as e:
            # Do not abort here; fall through to the retry/resume logic below.
            logger.error(f"Download interrupted: {e}")
        finally:
            r.close()

if total >= size:
break

        logger.error(f"Download incomplete, downloaded {total} bytes out of {size}")
        logger.warning(f"Sleeping {sleep} seconds")
time.sleep(sleep)
mode = "ab"
total = os.path.getsize(target)
sleep *= 1.5
headers = {"Range": f"bytes={total}-"}
tries += 1

if total != size:
raise Exception(f"Download failed: downloaded {total} bytes out of {size}")

elapsed = time.time() - start
if elapsed:
        logger.info(f"Download rate {total / elapsed:.2f} bytes/s")

return target
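
The retry loop above resumes interrupted transfers via HTTP range requests; here is a stripped-down sketch of just that mechanism (resume_download and its arguments are hypothetical helpers, not part of this PR):

import os
import requests

def resume_download(url, target):
    # Ask the server to start at the byte offset already on disk, then append.
    offset = os.path.getsize(target) if os.path.exists(target) else 0
    headers = {"Range": f"bytes={offset}-"} if offset else None
    with requests.get(url, stream=True, headers=headers) as r:
        r.raise_for_status()
        with open(target, "ab" if offset else "wb") as f:
            for chunk in r.iter_content(chunk_size=1024 * 1024):
                if chunk:
                    f.write(chunk)
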


def update_progress_bar():
"""
Update a progress bar that shows the percentage of all files being
downloaded.

Files that have reached 100% are removed from the display. Only
short aliases are displayed.
"""
completed_files = [
file for file, progress in download_status.items() if progress >= 100
]

# Remove completed files from the progress dictionary
for file in completed_files:
del download_status[file]
del file_aliases[file] # Remove alias as well

    if not download_status:
        # No active downloads: blank out the progress line and reset the cursor
        print("\r" + " " * 80 + "\r", end="")
        return

# Only display the top N files to avoid multi-line output
displayed_files = list(download_status.items())[:MAX_DISPLAY_FILES]

# Create progress string using the short aliases
progress = " | ".join(
[
f"{file_aliases[file]}: {int(progress)}%"
for file, progress in displayed_files
]
)

# If there are more files, show a summary
if len(download_status) > MAX_DISPLAY_FILES:
progress += f" | ... and {len(download_status) - MAX_DISPLAY_FILES} more"

# Use \r to overwrite the same line
print(f"\r{progress}", end="")
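
The display can be exercised without a real download by driving the module globals by hand (a throwaway test sketch; the filename is made up):

import time

for pct in (0, 40, 80, 100):
    file_aliases.setdefault("era5_2020-01.nc", "f1")
    download_status["era5_2020-01.nc"] = pct
    update_progress_bar()  # prints e.g. "f1: 40%" on a single line
    time.sleep(0.2)
# At 100% the file and its alias are dropped and the line is cleared.
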


def retrieve_data(product, chunks=None, tmpdir=None, lock=None, **updates):
"""
Download data like ERA5 from the Climate Data Store (CDS).
@@ -337,6 +483,21 @@ def retrieve_data(product, chunks=None, tmpdir=None, lock=None, **updates):
request
), "Need to specify at least 'variable', 'year' and 'month'"

    # Use tmpdir as the cache directory; if not provided, fall back to the
    # current working directory and mark the download for deletion once the
    # dataset is closed, preserving the previous temporary-file behaviour.
    delete_on_close = tmpdir is None
    if tmpdir is None:
        tmpdir = os.getcwd()
    cache_dir = tmpdir
    os.makedirs(cache_dir, exist_ok=True)

# Generate cache filename based on request
cache_filename = get_cache_filename(request, cache_dir)
cache_filepath = os.path.join(cache_dir, cache_filename)

if os.path.exists(cache_filepath):
        logger.info(f"Using cached file for request: {cache_filename}")
ds = xr.open_dataset(cache_filepath, chunks=chunks or {})
return ds

client = cdsapi.Client(
info_callback=logger.debug, debug=logging.DEBUG >= logging.root.level
)
@@ -349,25 +510,28 @@ def retrieve_data(product, chunks=None, tmpdir=None, lock=None, **updates):
fd, target = mkstemp(suffix=".nc", dir=tmpdir)
os.close(fd)

# Inform user about data being downloaded as "* variable (year-month)"
timestr = f"{request['year']}-{request['month']}"
variables = atleast_1d(request["variable"])
varstr = "\n\t".join([f"{v} ({timestr})" for v in variables])
filename = f"{variables[0]}_{timestr}.nc"
logger.info(f"CDS: Downloading variables\n\t{varstr}\n")
custom_download(result.location, result.content_length, target, lock, filename)

    # Move the downloaded file to the cache directory
os.rename(target, cache_filepath)
ds = xr.open_dataset(cache_filepath, chunks=chunks or {})
    if delete_on_close:
        logger.debug(f"Adding finalizer for {cache_filepath}")
        weakref.finalize(ds._file_obj._manager, noisy_unlink, cache_filepath)

# Remove default encoding we get from CDSAPI, which can lead to NaN values after loading with subsequent
# saving due to how xarray handles netcdf compression (only float encoded as short int seem affected)
# Fixes issue by keeping "float32" encoded as "float32" instead of internally saving as "short int", see:
# https://stackoverflow.com/questions/75755441/why-does-saving-to-netcdf-without-encoding-change-some-values-to-nan
# and hopefully fixed soon (could then remove), see https://github.com/pydata/xarray/issues/7691
for v in ds.data_vars:
if ds[v].encoding.get("dtype") == "int16":
ds[v].encoding.clear()

return ds
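
Assuming valid CDS credentials in ~/.cdsapirc, the caching then behaves as below (the product name is the real ERA5 single-levels dataset; the request keys shown are only the minimal illustrative subset):

# Identical requests: only the first call downloads, the second is served
# from ./era5_cache/<md5>.nc.
ds = retrieve_data(
    "reanalysis-era5-single-levels",
    tmpdir="./era5_cache",
    product_type="reanalysis",
    variable=["2m_temperature"],
    year=2020,
    month=1,
)
ds_cached = retrieve_data(
    "reanalysis-era5-single-levels",
    tmpdir="./era5_cache",
    product_type="reanalysis",
    variable=["2m_temperature"],
    year=2020,
    month=1,
)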