diff --git a/dep_tools/namers.py b/dep_tools/namers.py index e16747f..916ec59 100644 --- a/dep_tools/namers.py +++ b/dep_tools/namers.py @@ -69,7 +69,7 @@ class LocalPath(DepItemPath): def __init__(self, local_folder: str, **kwargs): # Need to create an abc for DepItemPath and drop this super().__init__(**kwargs) - self._folder_prefix = local_folder + self._folder_prefix = f"{local_folder}/dep_{self.sensor}_{self.dataset_id}/{self.version}" def _folder(self, item_id) -> str: return self._folder_prefix diff --git a/dep_tools/utils.py b/dep_tools/utils.py index 1f7a6b9..81494d3 100644 --- a/dep_tools/utils.py +++ b/dep_tools/utils.py @@ -155,7 +155,7 @@ def write_to_blob_storage( write_args: Dict = dict(), overwrite: bool = True, use_odc_writer: bool = False, - client: ContainerClient = None, + client: ContainerClient | None = None, **kwargs, ) -> None: # Allowing for a shared container client, which might be diff --git a/dep_tools/writers.py b/dep_tools/writers.py index d8ccc23..d68d047 100644 --- a/dep_tools/writers.py +++ b/dep_tools/writers.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field +from dataclasses import field from functools import partial from typing import Callable, Dict, Hashable, List, Union @@ -9,6 +9,8 @@ from urlpath import URL from xarray import DataArray, Dataset +from azure.storage.blob import ContainerClient + from .azure import get_container_client from .namers import DepItemPath from .stac_utils import write_stac_blob_storage, write_stac_local @@ -24,16 +26,20 @@ def write(self, data, id) -> str: pass -@dataclass class XrWriterMixin(object): - itempath: DepItemPath - overwrite: bool = False - convert_to_int16: bool = True - output_value_multiplier: int = 10000 - scale_int16s: bool = False - output_nodata: int = -32767 - extra_attrs: Dict = field(default_factory=dict) - use_odc_writer: bool = True + def __init__( + self, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): + self.convert_to_int16 = convert_to_int16 + self.output_value_multiplier = output_value_multiplier + self.scale_int16s = scale_int16s + self.output_nodata = output_nodata + self.extra_attrs = extra_attrs def prep(self, xr: Union[DataArray, Dataset]): xr.attrs.update(self.extra_attrs) @@ -47,18 +53,41 @@ def prep(self, xr: Union[DataArray, Dataset]): return xr -@dataclass class DsWriter(XrWriterMixin, Writer): - write_function: Callable = write_to_blob_storage - write_multithreaded: bool = False - write_stac_function: Callable = write_stac_blob_storage - write_stac: bool = True + def __init__( + self, + itempath: DepItemPath, + use_odc_writer: bool = True, + overwrite: bool = False, + write_function: Callable = write_to_blob_storage, + write_stac_function: Callable = write_stac_blob_storage, + write_stac: bool = True, + write_multithreaded: bool = False, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): + self.itempath = itempath + self.use_odc_writer = use_odc_writer + self.overwrite = overwrite + self.write_function = write_function + self.write_stac_function = write_stac_function + self.write_stac = write_stac + self.write_multithreaded = write_multithreaded + super().__init__( + convert_to_int16, + output_value_multiplier, + scale_int16s, + output_nodata, + extra_attrs, + ) def write(self, xr: Dataset, item_id: str) -> str | List: xr = super().prep(xr) paths = [] assets = {} - client = get_container_client() def get_write_partial(variable: Hashable) -> Callable: output_da = xr[variable].squeeze() @@ -72,7 +101,6 @@ def get_write_partial(variable: Hashable) -> Callable: write_args=dict(driver="COG"), overwrite=self.overwrite, use_odc_writer=self.use_odc_writer, - client=client, ) if self.write_multithreaded: @@ -118,14 +146,63 @@ def get_write_partial(variable: Hashable) -> Callable: class LocalDsWriter(DsWriter): - def __init__(self, use_odc_writer: bool = False, **kwargs): + def __init__( + self, + itempath: DepItemPath, + use_odc_writer: bool = True, + overwrite: bool = False, + write_stac: bool = True, + write_multithreaded: bool = False, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): super().__init__( + itempath=itempath, + use_odc_writer=use_odc_writer, + overwrite=overwrite, write_function=write_to_local_storage, write_stac_function=write_stac_local, - use_odc_writer=use_odc_writer, - **kwargs, + write_stac=write_stac, + write_multithreaded=write_multithreaded, + convert_to_int16=convert_to_int16, + output_value_multiplier=output_value_multiplier, + scale_int16s=scale_int16s, + output_nodata=output_nodata, + extra_attrs=extra_attrs, ) class AzureDsWriter(DsWriter): - pass + def __init__( + self, + itempath: DepItemPath, + client: ContainerClient | None = None, + use_odc_writer: bool = True, + overwrite: bool = False, + write_stac: bool = True, + write_multithreaded: bool = False, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): + self.client = get_container_client() if client is None else client + write_function = partial(write_to_blob_storage, client=client) + super().__init__( + itempath=itempath, + use_odc_writer=use_odc_writer, + overwrite=overwrite, + write_function=write_function, + write_stac_function=write_stac_blob_storage, + write_stac=write_stac, + write_multithreaded=write_multithreaded, + convert_to_int16=convert_to_int16, + output_value_multiplier=output_value_multiplier, + scale_int16s=scale_int16s, + output_nodata=output_nodata, + extra_attrs=extra_attrs, + )