Skip to content

Commit

Permalink
Merge pull request #34 from digitalearthpacific/fix-local-writing
Browse files Browse the repository at this point in the history
Fix local writing
  • Loading branch information
alexgleith authored Dec 4, 2023
2 parents cad7afb + c00933b commit 1b22cec
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 23 deletions.
2 changes: 1 addition & 1 deletion dep_tools/namers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class LocalPath(DepItemPath):
def __init__(self, local_folder: str, **kwargs):
    """Build an item path rooted at a local folder.

    Args:
        local_folder: Base directory for output. The dataset-specific
            prefix ``dep_<sensor>_<dataset_id>/<version>`` is appended.
        **kwargs: Forwarded unchanged to ``DepItemPath.__init__``.
    """
    # Need to create an abc for DepItemPath and drop this
    super().__init__(**kwargs)
    # Removed the dead `self._folder_prefix = local_folder` assignment that
    # was immediately overwritten by the line below.
    # NOTE(review): sensor/dataset_id/version are presumably set by
    # super().__init__ — confirm against DepItemPath.
    self._folder_prefix = (
        f"{local_folder}/dep_{self.sensor}_{self.dataset_id}/{self.version}"
    )

def _folder(self, item_id) -> str:
    """Return the output folder for *item_id*.

    ``item_id`` is ignored: every item written through a LocalPath shares
    the single folder prefix built in ``__init__``.
    """
    return self._folder_prefix
2 changes: 1 addition & 1 deletion dep_tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def write_to_blob_storage(
write_args: Dict = dict(),
overwrite: bool = True,
use_odc_writer: bool = False,
client: ContainerClient = None,
client: ContainerClient | None = None,
**kwargs,
) -> None:
# Allowing for a shared container client, which might be
Expand Down
119 changes: 98 additions & 21 deletions dep_tools/writers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from dataclasses import field
from functools import partial
from typing import Callable, Dict, Hashable, List, Union

Expand All @@ -9,6 +9,8 @@
from urlpath import URL
from xarray import DataArray, Dataset

from azure.storage.blob import ContainerClient

from .azure import get_container_client
from .namers import DepItemPath
from .stac_utils import write_stac_blob_storage, write_stac_local
Expand All @@ -24,16 +26,20 @@ def write(self, data, id) -> str:
pass


@dataclass
class XrWriterMixin(object):
itempath: DepItemPath
overwrite: bool = False
convert_to_int16: bool = True
output_value_multiplier: int = 10000
scale_int16s: bool = False
output_nodata: int = -32767
extra_attrs: Dict = field(default_factory=dict)
use_odc_writer: bool = True
def __init__(
    self,
    convert_to_int16: bool = True,
    output_value_multiplier: int = 10000,
    scale_int16s: bool = False,
    output_nodata: int = -32767,
    extra_attrs: Dict | None = None,
):
    """Configure how xarray output is prepared before writing.

    Args:
        convert_to_int16: Convert floating-point data to int16 on output.
        output_value_multiplier: Scale factor applied when converting.
        scale_int16s: Whether to also rescale data that is already int16.
        output_nodata: Nodata sentinel written into the output.
        extra_attrs: Extra attributes merged into the output's attrs;
            defaults to an empty dict.
    """
    self.convert_to_int16 = convert_to_int16
    self.output_value_multiplier = output_value_multiplier
    self.scale_int16s = scale_int16s
    self.output_nodata = output_nodata
    # `field(default_factory=dict)` is only meaningful inside a @dataclass;
    # as a plain function default it evaluates to a dataclasses.Field
    # object, which would later break `xr.attrs.update(...)`. Use the
    # None-sentinel idiom instead.
    self.extra_attrs = {} if extra_attrs is None else extra_attrs

def prep(self, xr: Union[DataArray, Dataset]):
xr.attrs.update(self.extra_attrs)
Expand All @@ -47,18 +53,41 @@ def prep(self, xr: Union[DataArray, Dataset]):
return xr


@dataclass
class DsWriter(XrWriterMixin, Writer):
write_function: Callable = write_to_blob_storage
write_multithreaded: bool = False
write_stac_function: Callable = write_stac_blob_storage
write_stac: bool = True
def __init__(
    self,
    itempath: DepItemPath,
    use_odc_writer: bool = True,
    overwrite: bool = False,
    write_function: Callable = write_to_blob_storage,
    write_stac_function: Callable = write_stac_blob_storage,
    write_stac: bool = True,
    write_multithreaded: bool = False,
    convert_to_int16: bool = True,
    output_value_multiplier: int = 10000,
    scale_int16s: bool = False,
    output_nodata: int = -32767,
    extra_attrs: Dict | None = None,
):
    """Create a Dataset writer.

    Args:
        itempath: Naming helper that maps item ids to output paths.
        use_odc_writer: Use the ODC COG writer rather than rioxarray.
        overwrite: Overwrite existing outputs instead of skipping them.
        write_function: Callable that persists a single DataArray.
        write_stac_function: Callable that writes the STAC item.
        write_stac: Whether to emit a STAC item alongside the data.
        write_multithreaded: Write variables in parallel via a thread pool.
        convert_to_int16 / output_value_multiplier / scale_int16s /
        output_nodata / extra_attrs: See XrWriterMixin.
    """
    self.itempath = itempath
    self.use_odc_writer = use_odc_writer
    self.overwrite = overwrite
    self.write_function = write_function
    self.write_stac_function = write_stac_function
    self.write_stac = write_stac
    self.write_multithreaded = write_multithreaded
    # Normalize extra_attrs here so the mixin never receives None. The
    # previous `field(default_factory=dict)` default produced a
    # dataclasses.Field object, not a dict.
    super().__init__(
        convert_to_int16,
        output_value_multiplier,
        scale_int16s,
        output_nodata,
        {} if extra_attrs is None else extra_attrs,
    )

def write(self, xr: Dataset, item_id: str) -> str | List:
xr = super().prep(xr)
paths = []
assets = {}
client = get_container_client()

def get_write_partial(variable: Hashable) -> Callable:
output_da = xr[variable].squeeze()
Expand All @@ -72,7 +101,6 @@ def get_write_partial(variable: Hashable) -> Callable:
write_args=dict(driver="COG"),
overwrite=self.overwrite,
use_odc_writer=self.use_odc_writer,
client=client,
)

if self.write_multithreaded:
Expand Down Expand Up @@ -118,14 +146,63 @@ def get_write_partial(variable: Hashable) -> Callable:


class LocalDsWriter(DsWriter):
    """DsWriter preconfigured to write data and STAC items to local disk."""

    def __init__(
        self,
        itempath: DepItemPath,
        use_odc_writer: bool = True,
        overwrite: bool = False,
        write_stac: bool = True,
        write_multithreaded: bool = False,
        convert_to_int16: bool = True,
        output_value_multiplier: int = 10000,
        scale_int16s: bool = False,
        output_nodata: int = -32767,
        extra_attrs: Dict | None = None,
    ):
        # Fixed diff residue: the duplicate `use_odc_writer=` keyword (a
        # SyntaxError) and the stray `**kwargs` (no such parameter exists)
        # are removed; `extra_attrs` uses the None-sentinel idiom instead
        # of a `field(default_factory=dict)` default.
        super().__init__(
            itempath=itempath,
            use_odc_writer=use_odc_writer,
            overwrite=overwrite,
            write_function=write_to_local_storage,
            write_stac_function=write_stac_local,
            write_stac=write_stac,
            write_multithreaded=write_multithreaded,
            convert_to_int16=convert_to_int16,
            output_value_multiplier=output_value_multiplier,
            scale_int16s=scale_int16s,
            output_nodata=output_nodata,
            extra_attrs={} if extra_attrs is None else extra_attrs,
        )


class AzureDsWriter(DsWriter):
    """DsWriter preconfigured to write data and STAC items to Azure blob storage."""

    def __init__(
        self,
        itempath: DepItemPath,
        client: ContainerClient | None = None,
        use_odc_writer: bool = True,
        overwrite: bool = False,
        write_stac: bool = True,
        write_multithreaded: bool = False,
        convert_to_int16: bool = True,
        output_value_multiplier: int = 10000,
        scale_int16s: bool = False,
        output_nodata: int = -32767,
        extra_attrs: Dict | None = None,
    ):
        # Resolve a default container client once so every write shares it.
        self.client = get_container_client() if client is None else client
        # Bug fix: bind the *resolved* client, not the raw argument —
        # previously a None default was forwarded to write_to_blob_storage,
        # defeating the shared-client intent of self.client above.
        write_function = partial(write_to_blob_storage, client=self.client)
        super().__init__(
            itempath=itempath,
            use_odc_writer=use_odc_writer,
            overwrite=overwrite,
            write_function=write_function,
            write_stac_function=write_stac_blob_storage,
            write_stac=write_stac,
            write_multithreaded=write_multithreaded,
            convert_to_int16=convert_to_int16,
            output_value_multiplier=output_value_multiplier,
            scale_int16s=scale_int16s,
            output_nodata=output_nodata,
            extra_attrs={} if extra_attrs is None else extra_attrs,
        )

0 comments on commit 1b22cec

Please sign in to comment.