From b7fbf934ef0bbd3d3b9604a489b66dd6fb661806 Mon Sep 17 00:00:00 2001 From: Alex Leith Date: Fri, 1 Dec 2023 13:09:11 +1100 Subject: [PATCH 1/3] Add more info to the folder path --- dep_tools/namers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dep_tools/namers.py b/dep_tools/namers.py index e16747f..916ec59 100644 --- a/dep_tools/namers.py +++ b/dep_tools/namers.py @@ -69,7 +69,7 @@ class LocalPath(DepItemPath): def __init__(self, local_folder: str, **kwargs): # Need to create an abc for DepItemPath and drop this super().__init__(**kwargs) - self._folder_prefix = local_folder + self._folder_prefix = f"{local_folder}/dep_{self.sensor}_{self.dataset_id}/{self.version}" def _folder(self, item_id) -> str: return self._folder_prefix From 94b8888fc26cd8671a79ebcf8558038d8c9e0093 Mon Sep 17 00:00:00 2001 From: Alex Leith Date: Fri, 1 Dec 2023 13:09:24 +1100 Subject: [PATCH 2/3] Do not require a container client --- dep_tools/writers.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/dep_tools/writers.py b/dep_tools/writers.py index d8ccc23..a5020f6 100644 --- a/dep_tools/writers.py +++ b/dep_tools/writers.py @@ -9,6 +9,8 @@ from urlpath import URL from xarray import DataArray, Dataset +from azure.storage.blob import ContainerClient + from .azure import get_container_client from .namers import DepItemPath from .stac_utils import write_stac_blob_storage, write_stac_local @@ -34,6 +36,7 @@ class XrWriterMixin(object): output_nodata: int = -32767 extra_attrs: Dict = field(default_factory=dict) use_odc_writer: bool = True + client: ContainerClient = None def prep(self, xr: Union[DataArray, Dataset]): xr.attrs.update(self.extra_attrs) @@ -50,15 +53,14 @@ def prep(self, xr: Union[DataArray, Dataset]): @dataclass class DsWriter(XrWriterMixin, Writer): write_function: Callable = write_to_blob_storage - write_multithreaded: bool = False write_stac_function: Callable = write_stac_blob_storage write_stac: bool = True + write_multithreaded: bool = False def write(self, xr: Dataset, item_id: str) -> str | List: xr = super().prep(xr) paths = [] assets = {} - client = get_container_client() def get_write_partial(variable: Hashable) -> Callable: output_da = xr[variable].squeeze() @@ -72,7 +74,7 @@ def get_write_partial(variable: Hashable) -> Callable: write_args=dict(driver="COG"), overwrite=self.overwrite, use_odc_writer=self.use_odc_writer, - client=client, + client=self.client, ) if self.write_multithreaded: @@ -128,4 +130,10 @@ def __init__(self, use_odc_writer: bool = False, **kwargs): class AzureDsWriter(DsWriter): - pass + def __init__(self, **kwargs): + super().__init__( + write_function=write_to_blob_storage, + write_stac_function=write_stac_blob_storage, + **kwargs, + ) + self.client = get_container_client() From c00933bc03ef20485aae4af17674ff189fc928ee Mon Sep 17 00:00:00 2001 From: Jesse Anderson Date: Fri, 1 Dec 2023 13:55:11 -0700 Subject: [PATCH 3/3] clean up to help Alex --- dep_tools/utils.py | 2 +- dep_tools/writers.py | 117 ++++++++++++++++++++++++++++++++++--------- 2 files changed, 94 insertions(+), 25 deletions(-) diff --git a/dep_tools/utils.py b/dep_tools/utils.py index 1f7a6b9..81494d3 100644 --- a/dep_tools/utils.py +++ b/dep_tools/utils.py @@ -155,7 +155,7 @@ def write_to_blob_storage( write_args: Dict = dict(), overwrite: bool = True, use_odc_writer: bool = False, - client: ContainerClient = None, + client: ContainerClient | None = None, **kwargs, ) -> None: # Allowing for a shared container client, which might be diff --git a/dep_tools/writers.py b/dep_tools/writers.py index a5020f6..d68d047 100644 --- a/dep_tools/writers.py +++ b/dep_tools/writers.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor -from dataclasses import dataclass, field +from dataclasses import field from functools import partial from typing import Callable, Dict, Hashable, List, Union @@ -26,17 +26,20 @@ def write(self, data, id) -> str: pass -@dataclass class XrWriterMixin(object): - itempath: DepItemPath - overwrite: bool = False - convert_to_int16: bool = True - output_value_multiplier: int = 10000 - scale_int16s: bool = False - output_nodata: int = -32767 - extra_attrs: Dict = field(default_factory=dict) - use_odc_writer: bool = True - client: ContainerClient = None + def __init__( + self, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): + self.convert_to_int16 = convert_to_int16 + self.output_value_multiplier = output_value_multiplier + self.scale_int16s = scale_int16s + self.output_nodata = output_nodata + self.extra_attrs = extra_attrs def prep(self, xr: Union[DataArray, Dataset]): xr.attrs.update(self.extra_attrs) @@ -50,12 +53,36 @@ def prep(self, xr: Union[DataArray, Dataset]): return xr -@dataclass class DsWriter(XrWriterMixin, Writer): - write_function: Callable = write_to_blob_storage - write_stac_function: Callable = write_stac_blob_storage - write_stac: bool = True - write_multithreaded: bool = False + def __init__( + self, + itempath: DepItemPath, + use_odc_writer: bool = True, + overwrite: bool = False, + write_function: Callable = write_to_blob_storage, + write_stac_function: Callable = write_stac_blob_storage, + write_stac: bool = True, + write_multithreaded: bool = False, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): + self.itempath = itempath + self.use_odc_writer = use_odc_writer + self.overwrite = overwrite + self.write_function = write_function + self.write_stac_function = write_stac_function + self.write_stac = write_stac + self.write_multithreaded = write_multithreaded + super().__init__( + convert_to_int16, + output_value_multiplier, + scale_int16s, + output_nodata, + extra_attrs, + ) def write(self, xr: Dataset, item_id: str) -> str | List: xr = super().prep(xr) @@ -74,7 +101,6 @@ def get_write_partial(variable: Hashable) -> Callable: write_args=dict(driver="COG"), overwrite=self.overwrite, use_odc_writer=self.use_odc_writer, - client=self.client, ) if self.write_multithreaded: @@ -120,20 +146,63 @@ def get_write_partial(variable: Hashable) -> Callable: class LocalDsWriter(DsWriter): - def __init__(self, use_odc_writer: bool = False, **kwargs): + def __init__( + self, + itempath: DepItemPath, + use_odc_writer: bool = True, + overwrite: bool = False, + write_stac: bool = True, + write_multithreaded: bool = False, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): super().__init__( + itempath=itempath, + use_odc_writer=use_odc_writer, + overwrite=overwrite, write_function=write_to_local_storage, write_stac_function=write_stac_local, - use_odc_writer=use_odc_writer, - **kwargs, + write_stac=write_stac, + write_multithreaded=write_multithreaded, + convert_to_int16=convert_to_int16, + output_value_multiplier=output_value_multiplier, + scale_int16s=scale_int16s, + output_nodata=output_nodata, + extra_attrs=extra_attrs, ) class AzureDsWriter(DsWriter): - def __init__(self, **kwargs): + def __init__( + self, + itempath: DepItemPath, + client: ContainerClient | None = None, + use_odc_writer: bool = True, + overwrite: bool = False, + write_stac: bool = True, + write_multithreaded: bool = False, + convert_to_int16: bool = True, + output_value_multiplier: int = 10000, + scale_int16s: bool = False, + output_nodata: int = -32767, + extra_attrs: Dict = field(default_factory=dict), + ): + self.client = get_container_client() if client is None else client + write_function = partial(write_to_blob_storage, client=client) super().__init__( - write_function=write_to_blob_storage, + itempath=itempath, + use_odc_writer=use_odc_writer, + overwrite=overwrite, + write_function=write_function, write_stac_function=write_stac_blob_storage, - **kwargs, + write_stac=write_stac, + write_multithreaded=write_multithreaded, + convert_to_int16=convert_to_int16, + output_value_multiplier=output_value_multiplier, + scale_int16s=scale_int16s, + output_nodata=output_nodata, + extra_attrs=extra_attrs, ) - self.client = get_container_client()