Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor index types #405

Draft
wants to merge 14 commits into
base: beam-refactor
Choose a base branch
from
10 changes: 5 additions & 5 deletions pangeo_forge_recipes/combiners.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@
import apache_beam as beam

from .aggregation import XarrayCombineAccumulator, XarraySchema
from .patterns import CombineOp, DimKey, Index
from .types import CombineOp, Dimension, Index


@dataclass
class CombineXarraySchemas(beam.CombineFn):
"""A beam ``CombineFn`` which we can use to combine multiple xarray schemas
along a single dimension

:param dim_key: The dimension along which to combine
:param dimension: The dimension along which to combine
"""

dim_key: DimKey
dimension: Dimension

def get_position(self, index: Index) -> int:
return index[self.dim_key].position
return index[self.dimension].value

def create_accumulator(self) -> XarrayCombineAccumulator:
concat_dim = self.dim_key.name if self.dim_key.operation == CombineOp.CONCAT else None
concat_dim = self.dimension.name if self.dimension.operation == CombineOp.CONCAT else None
return XarrayCombineAccumulator(concat_dim=concat_dim)

def add_input(self, accumulator: XarrayCombineAccumulator, item: Tuple[Index, XarraySchema]):
Expand Down
120 changes: 28 additions & 92 deletions pangeo_forge_recipes/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,22 @@
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, Union

from .serialization import dict_drop_empty, dict_to_sha256
from .types import CombineOp, Dimension, Index, IndexedPosition, Position


class CombineOp(Enum):
"""Used to uniquely identify different combine operations across Pangeo Forge Recipes."""
@dataclass(frozen=True)
class CombineDim:
name: str
operation: ClassVar[CombineOp]
keys: Sequence[Any] = field(repr=False)

MERGE = 1
CONCAT = 2
SUBSET = 3
@property
def dimension(self):
return Dimension(self.name, self.operation)


@dataclass(frozen=True)
class ConcatDim:
class ConcatDim(CombineDim):
"""Represents a concatenation operation across a dimension of a FilePattern.

:param name: The name of the dimension we are concatenating over. For
Expand All @@ -34,14 +38,12 @@ class ConcatDim:
provide a fast path for recipes.
"""

name: str # should match the actual dimension name
keys: Sequence[Any] = field(repr=False)
nitems_per_file: Optional[int] = None
operation: ClassVar[CombineOp] = CombineOp.CONCAT


@dataclass(frozen=True)
class MergeDim:
class MergeDim(CombineDim):
"""Represents a merge operation across a dimension of a FilePattern.

:param name: The name of the dimension we are merging over. The actual
Expand All @@ -52,89 +54,20 @@ class MergeDim:
the file name.
"""

name: str
keys: Sequence[Any] = field(repr=False)
operation: ClassVar[CombineOp] = CombineOp.MERGE


# would it be simpler to just use a tuple?
@dataclass(frozen=True, order=True)
class DimKey:
"""
:param name: The name of the dimension we are combining over.
:param operation: What type of combination this is (merge or concat)
"""

name: str
operation: CombineOp


@dataclass(frozen=True, order=True)
class DimVal:
"""
:param position: Where this item lies within the sequence.
:param start: Where the starting array index for the item.
:param stop: The ending array index for the item.
"""

position: int
start: Optional[int] = None
stop: Optional[int] = None


# Alternative way of specifying type
# Index = dict[DimKey, DimVal]


class Index(Dict[DimKey, DimVal]):
"""An Index is a special sort of dictionary which describes a position within
a multidimensional set.

- The key is a :class:`DimKey` which tells us which dimension we are addressing.
- The value is a :class:`DimVal` which tells us where we are within that dimension.

This object is hashable and deterministically serializable.
"""

def __hash__(self):
return hash(tuple(self.__getstate__()))

def __getstate__(self):
return sorted(self.items())

def __setstate__(self, state):
self.__init__({k: v for k, v in state})

def find_concat_dim(self, dim_name: str):
possible_concat_dims = [
d for d in self if (d.name == dim_name and d.operation == CombineOp.CONCAT)
]
if len(possible_concat_dims) > 1:
raise ValueError(
f"Found {len(possible_concat_dims)} concat dims named {dim_name} "
f"in the index {self}."
)
elif len(possible_concat_dims) == 0:
return None
else:
key = possible_concat_dims[0]
return self[key]


CombineDim = Union[MergeDim, ConcatDim]


def augment_index_with_start_stop(dim_val: DimVal, item_lens: List[int]) -> DimVal:
def augment_index_with_start_stop(position: Position, item_lens: List[int]) -> IndexedPosition:
"""Take an index _without_ start / stop and add them based on the lens defined in sequence_lens.

:param index: The ``DimIndex`` instance to augment.
:param item_lens: A list of integer lengths for all items in the sequence.
"""

start = sum(item_lens[: dim_val.position])
stop = start + item_lens[dim_val.position]

return DimVal(dim_val.position, start, stop)
if position.indexed:
raise ValueError("This position is already indexed")
start = sum(item_lens[: position.value])
return IndexedPosition(start)


class AutoName(Enum):
Expand Down Expand Up @@ -239,31 +172,34 @@ def concat_sequence_lens(self) -> Dict[str, Optional[int]]:
}

@property
def combine_dim_keys(self) -> List[DimKey]:
return [DimKey(dim.name, dim.operation) for dim in self.combine_dims]
def combine_dim_keys(self) -> List[Dimension]:
return [Dimension(dim.name, dim.operation) for dim in self.combine_dims]

def __getitem__(self, indexer: Index) -> str:
"""Get a filename path for a particular key."""
assert len(indexer) == len(self.combine_dims)
format_function_kwargs = {}
for dimkey, dimval in indexer.items():
for dimension, position in indexer.items():
dims = [
dim
for dim in self.combine_dims
if dim.name == dimkey.name and dim.operation == dimkey.operation
combine_dim
for combine_dim in self.combine_dims
if combine_dim.dimension == dimension
]
if len(dims) != 1:
raise KeyError(r"Could not valid combine_dim for indexer {idx_key}")
raise KeyError(f"Could not valid combine_dim for dimension {dimension}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My personal preference is {dimension!r}.

dim = dims[0]
format_function_kwargs[dim.name] = dim.keys[dimval.position]
format_function_kwargs[dim.name] = dim.keys[position.value]
fname = self.format_function(**format_function_kwargs)
return fname

def __iter__(self) -> Iterator[Index]:
"""Iterate over all keys in the pattern."""
for val in product(*[range(n) for n in self.shape]):
index = Index(
{DimKey(op.name, op.operation): DimVal(v) for op, v in zip(self.combine_dims, val)}
{
Dimension(op.name, op.operation): Position(v)
for op, v in zip(self.combine_dims, val)
}
)
yield index

Expand Down
18 changes: 9 additions & 9 deletions pangeo_forge_recipes/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .aggregation import XarraySchema, dataset_to_schema, schema_to_zarr
from .combiners import CombineXarraySchemas
from .openers import open_url, open_with_xarray
from .patterns import CombineOp, DimKey, FileType, Index, augment_index_with_start_stop
from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_start_stop
from .storage import CacheFSSpecTarget, FSSpecTarget
from .writers import store_dataset_fragment

Expand Down Expand Up @@ -124,10 +124,10 @@ def expand(self, pcoll):
)


def _nest_dim(item: Indexed[T], dim_key: DimKey) -> Indexed[Indexed[T]]:
def _nest_dim(item: Indexed[T], dimension: Dimension) -> Indexed[Indexed[T]]:
index, value = item
inner_index = Index({dim_key: index[dim_key]})
outer_index = Index({dk: index[dk] for dk in index if dk != dim_key})
inner_index = Index({dimension: index[dimension]})
outer_index = Index({dk: index[dk] for dk in index if dk != dimension})
return outer_index, (inner_index, value)


Expand All @@ -136,13 +136,13 @@ class _NestDim(beam.PTransform):
"""Prepare a collection for grouping by transforming an Index into a nested
Tuple of Indexes.

:param dim_key: The dimension to nest
:param dimension: The dimension to nest
"""

dim_key: DimKey
dimension: Dimension

def expand(self, pcoll):
return pcoll | beam.Map(_nest_dim, dim_key=self.dim_key)
return pcoll | beam.Map(_nest_dim, dimension=self.dimension)


@dataclass
Expand All @@ -159,7 +159,7 @@ class DetermineSchema(beam.PTransform):
:param combine_dims: The dimensions to combine
"""

combine_dims: List[DimKey]
combine_dims: List[Dimension]

def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
cdims = self.combine_dims.copy()
Expand Down Expand Up @@ -250,7 +250,7 @@ class StoreToZarr(beam.PTransform):

# TODO: make it so we don't have to explictly specify combine_dims
# Could be inferred from the pattern instead
combine_dims: List[DimKey]
combine_dims: List[Dimension]
target_url: str
target_chunks: Dict[str, int] = field(default_factory=dict)

Expand Down
73 changes: 73 additions & 0 deletions pangeo_forge_recipes/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional


class CombineOp(Enum):
    """Used to uniquely identify different combine operations across Pangeo Forge Recipes."""

    MERGE = 1
    CONCAT = 2
    SUBSET = 3


@dataclass(frozen=True, order=True)
class Dimension:
    """Identifies a dimension together with the operation used to combine over it.

    Instances are frozen (hashable), so they can be used as dictionary keys.

    :param name: The name of the dimension we are combining over.
    :param operation: What type of combination this is (merge or concat)
    """

    # NOTE: the generated ordering compares ``(name, operation)`` tuples. Plain
    # ``Enum`` members do not support ``<``, so ordering two dimensions that share
    # a name raises ``TypeError``. ``Index`` therefore sorts with an explicit key.
    name: str
    operation: CombineOp


@dataclass(frozen=True, order=True)
class Position:
    """A location along a single :class:`Dimension`.

    :param value: The integer location.
    :param indexed: If True, this position represents an offset within a dataset
      If False, it is a position within a sequence.
    """

    value: int
    # TODO: consider using a ClassVar here
    indexed: bool = False


@dataclass(frozen=True, order=True)
class IndexedPosition(Position):
    """A :class:`Position` whose ``indexed`` flag defaults to True."""

    indexed: bool = True


class Index(Dict[Dimension, Position]):
    """An Index is a special sort of dictionary which describes a position within
    a multidimensional set.

    - The key is a :class:`Dimension` which tells us which dimension we are addressing.
    - The value is a :class:`Position` which tells us where we are within that dimension.

    This object is hashable and deterministically serializable.
    """

    def __hash__(self):
        """Hash the canonical (sorted) item list so insertion order is irrelevant."""
        return hash(tuple(self.__getstate__()))

    def __getstate__(self):
        """Return the items as a list of ``(Dimension, Position)`` pairs in a stable order.

        Items are ordered by dimension name, then by the integer value of the
        combine operation. Sorting the raw items would fall back to comparing
        ``CombineOp`` members whenever two dimensions share a name, and ``Enum``
        members are not orderable, which raises ``TypeError``.
        """
        return sorted(
            self.items(),
            key=lambda item: (item[0].name, item[0].operation.value),
        )

    def __setstate__(self, state):
        """Populate this index from the pairs produced by :meth:`__getstate__`.

        :param state: An iterable of ``(Dimension, Position)`` pairs.
        """
        self.update(dict(state))

    def find_concat_dim(self, dim_name: str) -> Optional[Dimension]:
        """Find the concat dimension called ``dim_name`` among this index's keys.

        :param dim_name: The dimension name to look for.
        :returns: The matching :class:`Dimension` key, or ``None`` if there is
          no concat dimension with that name.
        :raises ValueError: If more than one matching key is found.
        """
        matches = [
            dim for dim in self if dim.name == dim_name and dim.operation == CombineOp.CONCAT
        ]
        # Defensive check: equal Dimension keys collapse in a dict, so more than
        # one match is not expected for plain Dimension instances.
        if len(matches) > 1:
            raise ValueError(
                f"Found {len(matches)} concat dims named {dim_name} "
                f"in the index {self}."
            )
        return matches[0] if matches else None
16 changes: 9 additions & 7 deletions pangeo_forge_recipes/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
def _region_for(var: xr.Variable, index: Index) -> Tuple[slice, ...]:
region_slice = []
for dim, dimsize in var.sizes.items():
concat_dim_val = index.find_concat_dim(dim)
if concat_dim_val:
concat_dimension = index.find_concat_dim(dim)
if concat_dimension:
# we are concatenating over this dimension
assert concat_dim_val.start is not None
assert concat_dim_val.stop == concat_dim_val.start + dimsize
region_slice.append(slice(concat_dim_val.start, concat_dim_val.stop))
position = index[concat_dimension]
assert position.indexed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be useful to have an assert message, maybe with some message to help end-users debug.

start = position.value
stop = start + dimsize
region_slice.append(slice(start, stop))
else:
# we are writing the entire dimension
region_slice.append(slice(None))
Expand All @@ -36,15 +38,15 @@ def _store_data(vname: str, var: xr.Variable, index: Index, zgroup: zarr.Group)

def _is_first_item(index):
for _, v in index.items():
if v.position > 0:
if v.value > 0:
return False
return True


def _is_first_in_merge_dim(index):
for k, v in index.items():
if k.operation == CombineOp.MERGE:
if v.position > 0:
if v.value > 0:
return False
return True

Expand Down
Loading