From 631e46401ca7bdb6dba22375792e0480c255d8b5 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sat, 13 Aug 2022 16:51:54 +0300
Subject: [PATCH 01/13] first pass working rechunking

---
 pangeo_forge_recipes/patterns.py   |  2 +-
 pangeo_forge_recipes/rechunking.py | 62 ++++++++++++++++++++++++++++++
 tests/test_rechunking.py           | 41 ++++++++++++++++++++
 3 files changed, 104 insertions(+), 1 deletion(-)
 create mode 100644 pangeo_forge_recipes/rechunking.py
 create mode 100644 tests/test_rechunking.py

diff --git a/pangeo_forge_recipes/patterns.py b/pangeo_forge_recipes/patterns.py
index c813b5fe..d234b8a5 100644
--- a/pangeo_forge_recipes/patterns.py
+++ b/pangeo_forge_recipes/patterns.py
@@ -105,7 +105,7 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.__init__({k: v for k, v in state})
 
-    def find_concat_dim(self, dim_name: str):
+    def find_concat_dim(self, dim_name: str) -> Optional[DimVal]:
         possible_concat_dims = [
             d for d in self if (d.name == dim_name and d.operation == CombineOp.CONCAT)
         ]
diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
new file mode 100644
index 00000000..cc64a63b
--- /dev/null
+++ b/pangeo_forge_recipes/rechunking.py
@@ -0,0 +1,62 @@
+import itertools
+from typing import Dict, Tuple
+
+import xarray as xr
+
+from .chunk_grid import ChunkGrid
+from .patterns import CombineOp, DimKey, DimVal, Index
+
+ChunkDimDict = Dict[str, Tuple[int, int]]
+
+# group keys are a tuple of tuples like (("lon", 1), ("time", 0))
+# the ints are chunk indexes
+# code should always sort the key before emitting it
+GroupKey = Tuple[Tuple[str, int], ...]
+
+
+def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: ChunkDimDict):
+    index, ds = fragment
+    chunk_grid = ChunkGrid.from_uniform_grid(target_chunks_and_dims)
+
+    fragment_slices = {}  # type: Dict[str, slice]
+    for dim in target_chunks_and_dims:
+        concat_dim_key = index.find_concat_dim(dim)
+        if concat_dim_key:
+            # this dimension is present in the fragment as a concat dim
+            dim_slice = slice(concat_dim_key.start, concat_dim_key.stop)
+        else:
+            # If there is a target_chunk that is NOT present as a concat_dim in the fragment,
+            # then we can assume that the entire span of that dimension is present in the dataset
+            # This would arise e.g. when decimating a contiguous dimension
+            dim_slice = slice(0, ds.dims[dim])
+        fragment_slices[dim] = dim_slice
+
+    target_chunk_slices = chunk_grid.array_slice_to_chunk_slice(fragment_slices)
+
+    all_chunks = itertools.product(
+        *(
+            [(dim, n) for n in range(chunk_slice.start, chunk_slice.stop)]
+            for dim, chunk_slice in target_chunk_slices.items()
+        )
+    )
+
+    for target_chunk_group in all_chunks:
+        # now we need to figure out which piece of the fragment belongs in which chunk
+        chunk_array_slices = chunk_grid.chunk_index_to_array_slice(dict(target_chunk_group))
+        sub_fragment_indexer = {}  # passed to ds.isel
+        sub_fragment_index = Index()
+        for dim, chunk_slice in chunk_array_slices.items():
+            fragment_slice = fragment_slices[dim]
+            start = max(chunk_slice.start, fragment_slice.start)
+            stop = min(chunk_slice.stop, fragment_slice.stop)
+            sub_fragment_indexer[dim] = slice(
+                start - fragment_slice.start, stop - fragment_slice.start
+            )
+            dim_key = DimKey(dim, CombineOp.CONCAT)
+            # I am getting the original "position" value from the original index
+            # Not sure if this makes sense. There is no way to know the actual position here
+            # without knowing all the previous subfragments
+            original_position = getattr(index.get(dim_key), "position", 0)
+            sub_fragment_index[dim_key] = DimVal(original_position, start, stop)
+        sub_fragment_ds = ds.isel(**sub_fragment_indexer)
+        yield tuple(sorted(target_chunk_group)), (sub_fragment_index, sub_fragment_ds)
diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
new file mode 100644
index 00000000..b6f80f8d
--- /dev/null
+++ b/tests/test_rechunking.py
@@ -0,0 +1,41 @@
+import math
+
+import pytest
+import xarray as xr
+
+from pangeo_forge_recipes.patterns import CombineOp, DimKey, DimVal, Index
+from pangeo_forge_recipes.rechunking import split_fragment
+
+from .data_generation import make_ds
+
+
+@pytest.mark.parametrize("time_chunks", [1, 3, 5, 10, 11])
+def test_split_fragment(time_chunks):
+
+    nt = 10
+    ds = make_ds(nt=nt)  # this represents a single dataset fragment
+
+    # in this case we have a single fragment which overlaps two
+    # target chunks; it needs to be split into two pieces
+    target_chunks_and_dims = {"time": (time_chunks, nt)}
+    ds_fragment = ds.isel(time=slice(0, nt))  # the whole thing
+    dim_key = DimKey("time", CombineOp.CONCAT)
+    index = Index({dim_key: DimVal(0, 0, nt)})
+
+    all_splits = list(
+        split_fragment((index, ds_fragment), target_chunks_and_dims=target_chunks_and_dims)
+    )
+
+    expected_chunks = math.ceil(nt / time_chunks)
+    assert len(all_splits) == expected_chunks
+
+    group_keys = [item[0] for item in all_splits]
+    new_indexes = [item[1][0] for item in all_splits]
+    new_datasets = [item[1][1] for item in all_splits]
+
+    assert group_keys == [(("time", n),) for n in range(expected_chunks)]
+
+    for n in range(expected_chunks):
+        start, stop = time_chunks * n, min(time_chunks * (n + 1), nt)
+        assert new_indexes[n] == Index({dim_key: DimVal(0, start, stop)})
+        xr.testing.assert_equal(new_datasets[n], ds.isel(time=slice(start, stop)))
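The splitting logic above reduces, per dimension, to a small piece of interval arithmetic. The following is a standalone sketch (plain Python, not part of the patch; overlapping_chunks is a hypothetical helper) of how a fragment spanning [offset, offset + nt) maps onto uniform target chunks:

import math

def overlapping_chunks(offset: int, nt: int, chunk_size: int):
    # first and last (exclusive) target chunk indexes touched by the fragment
    first = offset // chunk_size
    last = math.ceil((offset + nt) / chunk_size)
    for n in range(first, last):
        chunk_start, chunk_stop = n * chunk_size, (n + 1) * chunk_size
        # clip the chunk bounds to the fragment bounds
        yield n, (max(chunk_start, offset), min(chunk_stop, offset + nt))

# a fragment of length 10 at offset 5, split against target chunks of size 3:
assert list(overlapping_chunks(5, 10, 3)) == [
    (1, (5, 6)), (2, (6, 9)), (3, (9, 12)), (4, (12, 15)),
]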
From a95cc36edbf1b858dc46b0d9f957a5617ec61eae Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sat, 13 Aug 2022 17:58:16 +0300
Subject: [PATCH 02/13] new test wip

---
 tests/test_rechunking.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
index b6f80f8d..1e0650e8 100644
--- a/tests/test_rechunking.py
+++ b/tests/test_rechunking.py
@@ -9,22 +9,19 @@
 from .data_generation import make_ds
 
 
+@pytest.mark.parametrize("offset", [0, 5])  # hypothetical offset of this fragment
 @pytest.mark.parametrize("time_chunks", [1, 3, 5, 10, 11])
-def test_split_fragment(time_chunks):
+def test_split_fragment(time_chunks, offset):
+
+    nt_total = 20  # the total size of the hypothetical dataset
+    target_chunks_and_dims = {"time": (time_chunks, nt_total)}
 
     nt = 10
     ds = make_ds(nt=nt)  # this represents a single dataset fragment
-
-    # in this case we have a single fragment which overlaps two
-    # target chunks; it needs to be split into two pieces
-    target_chunks_and_dims = {"time": (time_chunks, nt)}
-    ds_fragment = ds.isel(time=slice(0, nt))  # the whole thing
     dim_key = DimKey("time", CombineOp.CONCAT)
-    index = Index({dim_key: DimVal(0, 0, nt)})
+    index = Index({dim_key: DimVal(0, offset, offset + nt)})
 
-    all_splits = list(
-        split_fragment((index, ds_fragment), target_chunks_and_dims=target_chunks_and_dims)
-    )
+    all_splits = list(split_fragment((index, ds), target_chunks_and_dims=target_chunks_and_dims))
 
     expected_chunks = math.ceil(nt / time_chunks)
     assert len(all_splits) == expected_chunks
@@ -33,9 +30,13 @@ def test_split_fragment(time_chunks):
     new_indexes = [item[1][0] for item in all_splits]
     new_datasets = [item[1][1] for item in all_splits]
 
-    assert group_keys == [(("time", n),) for n in range(expected_chunks)]
-
     for n in range(expected_chunks):
-        start, stop = time_chunks * n, min(time_chunks * (n + 1), nt)
-        assert new_indexes[n] == Index({dim_key: DimVal(0, start, stop)})
+        chunk_number = offset // time_chunks + n
+        assert group_keys[n] == (("time", chunk_number),)
+        chunk_start = time_chunks * chunk_number
+        chunk_stop = min(time_chunks * (chunk_number + 1), nt_total)
+        fragment_start = max(chunk_start, offset)
+        fragment_stop = min(chunk_stop, fragment_start + time_chunks)
+        assert new_indexes[n] == Index({dim_key: DimVal(0, fragment_start, fragment_stop)})
+        start, stop = fragment_start - offset, fragment_stop - offset
         xr.testing.assert_equal(new_datasets[n], ds.isel(time=slice(start, stop)))

From d459db3cee0e119f0997722dc5e139f911fd4fa0 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sat, 13 Aug 2022 18:35:10 +0300
Subject: [PATCH 03/13] improve rechunking test

---
 tests/test_rechunking.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
index 1e0650e8..36e8ac4a 100644
--- a/tests/test_rechunking.py
+++ b/tests/test_rechunking.py
@@ -1,5 +1,3 @@
-import math
-
 import pytest
 import xarray as xr
 
@@ -23,20 +21,21 @@ def test_split_fragment(time_chunks, offset):
 
     all_splits = list(split_fragment((index, ds), target_chunks_and_dims=target_chunks_and_dims))
 
-    expected_chunks = math.ceil(nt / time_chunks)
-    assert len(all_splits) == expected_chunks
-
     group_keys = [item[0] for item in all_splits]
     new_indexes = [item[1][0] for item in all_splits]
    new_datasets = [item[1][1] for item in all_splits]
 
-    for n in range(expected_chunks):
+    for n in range(len(all_splits)):
         chunk_number = offset // time_chunks + n
         assert group_keys[n] == (("time", chunk_number),)
         chunk_start = time_chunks * chunk_number
         chunk_stop = min(time_chunks * (chunk_number + 1), nt_total)
         fragment_start = max(chunk_start, offset)
-        fragment_stop = min(chunk_stop, fragment_start + time_chunks)
+        fragment_stop = min(chunk_stop, fragment_start + time_chunks, offset + nt)
         assert new_indexes[n] == Index({dim_key: DimVal(0, fragment_start, fragment_stop)})
         start, stop = fragment_start - offset, fragment_stop - offset
         xr.testing.assert_equal(new_datasets[n], ds.isel(time=slice(start, stop)))
+
+    # make sure we got the whole dataset back
+    ds_concat = xr.concat(new_datasets, "time")
+    xr.testing.assert_equal(ds, ds_concat)
From 4ca307626adb0ab3cf4301398f53e8b3d64fe2e5 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sat, 13 Aug 2022 18:51:25 +0300
Subject: [PATCH 04/13] multidim rechunking test

---
 tests/test_rechunking.py | 43 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
index 36e8ac4a..68b266ce 100644
--- a/tests/test_rechunking.py
+++ b/tests/test_rechunking.py
@@ -10,6 +10,7 @@
 @pytest.mark.parametrize("offset", [0, 5])  # hypothetical offset of this fragment
 @pytest.mark.parametrize("time_chunks", [1, 3, 5, 10, 11])
 def test_split_fragment(time_chunks, offset):
+    """A thorough test of 1D splitting logic that should cover all major edge cases."""
 
     nt_total = 20  # the total size of the hypothetical dataset
     target_chunks_and_dims = {"time": (time_chunks, nt_total)}
@@ -39,3 +40,45 @@ def test_split_fragment(time_chunks, offset):
     # make sure we got the whole dataset back
     ds_concat = xr.concat(new_datasets, "time")
     xr.testing.assert_equal(ds, ds_concat)
+
+
+def test_split_multidim():
+    """A simple test that checks whether splitting logic is applied correctly
+    for multiple dimensions."""
+
+    nt = 2
+    ds = make_ds(nt=nt)
+    nlat = ds.dims["lat"]
+    dim_key = DimKey("time", CombineOp.CONCAT)
+    index = Index({dim_key: DimVal(0, 0, nt)})
+
+    time_chunks = 1
+    lat_chunks = nlat // 2
+    target_chunks_and_dims = {"time": (time_chunks, nt), "lat": (lat_chunks, nlat)}
+
+    all_splits = list(split_fragment((index, ds), target_chunks_and_dims=target_chunks_and_dims))
+
+    group_keys = [item[0] for item in all_splits]
+
+    assert group_keys == [
+        (("lat", 0), ("time", 0)),
+        (("lat", 1), ("time", 0)),
+        (("lat", 0), ("time", 1)),
+        (("lat", 1), ("time", 1)),
+    ]
+
+    for group_key, (fragment_index, fragment_ds) in all_splits:
+        n_lat_chunk = group_key[0][1]
+        n_time_chunk = group_key[1][1]
+        time_start, time_stop = n_time_chunk * time_chunks, (n_time_chunk + 1) * time_chunks
+        lat_start, lat_stop = n_lat_chunk * lat_chunks, (n_lat_chunk + 1) * lat_chunks
+        expected_index = Index(
+            {
+                DimKey("time", CombineOp.CONCAT): DimVal(0, time_start, time_stop),
+                DimKey("lat", CombineOp.CONCAT): DimVal(0, lat_start, lat_stop),
+            }
+        )
+        assert fragment_index == expected_index
+        xr.testing.assert_equal(
+            fragment_ds, ds.isel(time=slice(time_start, time_stop), lat=slice(lat_start, lat_stop))
+        )
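The expected group keys in test_split_multidim follow mechanically from the itertools.product call in split_fragment. A minimal sketch of just that enumeration, with the chunk counts hard-coded:

import itertools

target_chunk_slices = {"time": range(2), "lat": range(2)}
all_chunks = itertools.product(
    *([(dim, n) for n in chunks] for dim, chunks in target_chunk_slices.items())
)
# sorting each key puts ("lat", ...) before ("time", ...) regardless of dict order
assert [tuple(sorted(group)) for group in all_chunks] == [
    (("lat", 0), ("time", 0)),
    (("lat", 1), ("time", 0)),
    (("lat", 0), ("time", 1)),
    (("lat", 1), ("time", 1)),
]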
From afa6e975933f0a4c12f8115ed0ca10f0c09ed1d7 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sun, 14 Aug 2022 07:14:04 +0300
Subject: [PATCH 05/13] change find_concat_dim behavior

---
 pangeo_forge_recipes/patterns.py   |  5 ++---
 pangeo_forge_recipes/rechunking.py |  3 ++-
 pangeo_forge_recipes/writers.py    |  5 +++--
 tests/test_patterns.py             | 12 ++++++------
 4 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/pangeo_forge_recipes/patterns.py b/pangeo_forge_recipes/patterns.py
index d234b8a5..9a78eacd 100644
--- a/pangeo_forge_recipes/patterns.py
+++ b/pangeo_forge_recipes/patterns.py
@@ -105,7 +105,7 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.__init__({k: v for k, v in state})
 
-    def find_concat_dim(self, dim_name: str) -> Optional[DimVal]:
+    def find_concat_dim(self, dim_name: str) -> Optional[DimKey]:
         possible_concat_dims = [
             d for d in self if (d.name == dim_name and d.operation == CombineOp.CONCAT)
         ]
@@ -117,8 +117,7 @@ def find_concat_dim(self, dim_name: str) -> Optional[DimVal]:
         elif len(possible_concat_dims) == 0:
             return None
         else:
-            key = possible_concat_dims[0]
-            return self[key]
+            return possible_concat_dims[0]
diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
index cc64a63b..c60146af 100644
--- a/pangeo_forge_recipes/rechunking.py
+++ b/pangeo_forge_recipes/rechunking.py
@@ -23,7 +23,8 @@ def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: C
         concat_dim_key = index.find_concat_dim(dim)
         if concat_dim_key:
             # this dimension is present in the fragment as a concat dim
-            dim_slice = slice(concat_dim_key.start, concat_dim_key.stop)
+            concat_dim_val = index[concat_dim_key]
+            dim_slice = slice(concat_dim_val.start, concat_dim_val.stop)
         else:
             # If there is a target_chunk that is NOT present as a concat_dim in the fragment,
             # then we can assume that the entire span of that dimension is present in the dataset
diff --git a/pangeo_forge_recipes/writers.py b/pangeo_forge_recipes/writers.py
index 352238eb..31aa74bd 100644
--- a/pangeo_forge_recipes/writers.py
+++ b/pangeo_forge_recipes/writers.py
@@ -10,9 +10,10 @@
 def _region_for(var: xr.Variable, index: Index) -> Tuple[slice, ...]:
     region_slice = []
     for dim, dimsize in var.sizes.items():
-        concat_dim_val = index.find_concat_dim(dim)
-        if concat_dim_val:
+        concat_dim_key = index.find_concat_dim(dim)
+        if concat_dim_key:
             # we are concatenating over this dimension
+            concat_dim_val = index[concat_dim_key]
             assert concat_dim_val.start is not None
             assert concat_dim_val.stop == concat_dim_val.start + dimsize
             region_slice.append(slice(concat_dim_val.start, concat_dim_val.stop))
diff --git a/tests/test_patterns.py b/tests/test_patterns.py
index b90e7dbc..68f679de 100644
--- a/tests/test_patterns.py
+++ b/tests/test_patterns.py
@@ -121,19 +121,19 @@ def test_file_pattern_concat_merge(runtime_secrets, pickle, concat_merge_pattern
     assert fp.nitems_per_input == {"time": None}
     assert fp.concat_sequence_lens == {"time": None}
     assert len(list(fp)) == 6
-    for key in fp:
-        concat_val = key.find_concat_dim("time")
-        assert key.find_concat_dim("foobar") is None
-        for dimkey, dimval in key.items():
+    for index in fp:
+        concat_dim_key = index.find_concat_dim("time")
+        assert index.find_concat_dim("foobar") is None
+        for dimkey, dimval in index.items():
             if dimkey.name == "time":
                 assert dimkey.operation == CombineOp.CONCAT
                 time_val = times[dimval.position]
-                assert dimval == concat_val
+                assert dimval == index[concat_dim_key]
             if dimkey.name == "variable":
                 assert dimkey.operation == CombineOp.MERGE
                 variable_val = varnames[dimval.position]
             expected_fname = format_function(time=time_val, variable=variable_val)
-            assert fp[key] == expected_fname
+            assert fp[index] == expected_fname
 
     if "fsspec_open_kwargs" in kwargs.keys():
         assert fp.file_type != FileType.opendap

From e82561d47db633a7cd8e65df7e013b793ffb25b3 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sun, 14 Aug 2022 07:31:37 +0300
Subject: [PATCH 06/13] make split_fragment preserve index elements not
 involved in the rechunking

---
 pangeo_forge_recipes/rechunking.py |  5 ++++-
 tests/test_rechunking.py           | 17 +++++++++++++++--
 2 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
index c60146af..dce3c5a7 100644
--- a/pangeo_forge_recipes/rechunking.py
+++ b/pangeo_forge_recipes/rechunking.py
@@ -19,12 +19,14 @@ def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: C
     chunk_grid = ChunkGrid.from_uniform_grid(target_chunks_and_dims)
 
     fragment_slices = {}  # type: Dict[str, slice]
+    keys_to_skip = []  # type: list[DimKey]
     for dim in target_chunks_and_dims:
         concat_dim_key = index.find_concat_dim(dim)
         if concat_dim_key:
             # this dimension is present in the fragment as a concat dim
             concat_dim_val = index[concat_dim_key]
             dim_slice = slice(concat_dim_val.start, concat_dim_val.stop)
+            keys_to_skip.append(concat_dim_key)
         else:
             # If there is a target_chunk that is NOT present as a concat_dim in the fragment,
             # then we can assume that the entire span of that dimension is present in the dataset
@@ -45,7 +47,8 @@ def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: C
         # now we need to figure out which piece of the fragment belongs in which chunk
         chunk_array_slices = chunk_grid.chunk_index_to_array_slice(dict(target_chunk_group))
         sub_fragment_indexer = {}  # passed to ds.isel
-        sub_fragment_index = Index()
+        # initialize the new index with the items we want to keep from the original index
+        sub_fragment_index = Index({k: v for k, v in index.items() if k not in keys_to_skip})
         for dim, chunk_slice in chunk_array_slices.items():
             fragment_slice = fragment_slices[dim]
             start = max(chunk_slice.start, fragment_slice.start)
diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
index 68b266ce..416b9f76 100644
--- a/tests/test_rechunking.py
+++ b/tests/test_rechunking.py
@@ -18,7 +18,13 @@ def test_split_fragment(time_chunks, offset):
     nt = 10
     ds = make_ds(nt=nt)  # this represents a single dataset fragment
     dim_key = DimKey("time", CombineOp.CONCAT)
-    index = Index({dim_key: DimVal(0, offset, offset + nt)})
+
+    extra_indexes = [
+        (DimKey("foo", CombineOp.CONCAT), DimVal(0)),
+        (DimKey("bar", CombineOp.MERGE), DimVal(1)),
+    ]
+
+    index = Index([(dim_key, DimVal(0, offset, offset + nt))] + extra_indexes)
 
     all_splits = list(split_fragment((index, ds), target_chunks_and_dims=target_chunks_and_dims))
 
@@ -33,7 +39,10 @@ def test_split_fragment(time_chunks, offset):
         chunk_stop = min(time_chunks * (chunk_number + 1), nt_total)
         fragment_start = max(chunk_start, offset)
         fragment_stop = min(chunk_stop, fragment_start + time_chunks, offset + nt)
-        assert new_indexes[n] == Index({dim_key: DimVal(0, fragment_start, fragment_stop)})
+        # other dimensions in the index should be passed through unchanged
+        assert new_indexes[n] == Index(
+            [(dim_key, DimVal(0, fragment_start, fragment_stop))] + extra_indexes
+        )
         start, stop = fragment_start - offset, fragment_stop - offset
         xr.testing.assert_equal(new_datasets[n], ds.isel(time=slice(start, stop)))
 
@@ -82,3 +91,7 @@ def test_split_multidim():
         xr.testing.assert_equal(
             fragment_ds, ds.isel(time=slice(time_start, time_stop), lat=slice(lat_start, lat_stop))
         )
+
+
+def test_combine_fragments():
+    """Basically just the inverse of the ``test_split_multidim`` test."""
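With the change in PATCH 05, find_concat_dim returns the key rather than the value, so callers do an explicit lookup. A short sketch against the API as of this commit (these class names are replaced in PATCH 10):

from pangeo_forge_recipes.patterns import CombineOp, DimKey, DimVal, Index

index = Index({DimKey("time", CombineOp.CONCAT): DimVal(0, 0, 10)})
key = index.find_concat_dim("time")  # now a DimKey, not a DimVal
assert key == DimKey("time", CombineOp.CONCAT)
assert index[key] == DimVal(0, 0, 10)  # the value is looked up explicitly
assert index.find_concat_dim("lon") is None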
From 69d31866940eb773ede87f8e6e779bc2701896ea Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sun, 14 Aug 2022 09:28:38 +0300
Subject: [PATCH 07/13] first working combine_fragments

---
 pangeo_forge_recipes/rechunking.py | 49 +++++++++++++++++++++++++++++-
 tests/test_rechunking.py           | 24 +++++++++++++--
 2 files changed, 70 insertions(+), 3 deletions(-)

diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
index dce3c5a7..5a6a4a39 100644
--- a/pangeo_forge_recipes/rechunking.py
+++ b/pangeo_forge_recipes/rechunking.py
@@ -1,6 +1,7 @@
 import itertools
-from typing import Dict, Tuple
+from typing import Dict, List, Tuple
 
+import numpy as np
 import xarray as xr
 
 from .chunk_grid import ChunkGrid
@@ -48,6 +49,7 @@ def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: C
         chunk_array_slices = chunk_grid.chunk_index_to_array_slice(dict(target_chunk_group))
         sub_fragment_indexer = {}  # passed to ds.isel
         # initialize the new index with the items we want to keep from the original index
+        # TODO: think about whether we want to always rechunk concat dims
         sub_fragment_index = Index({k: v for k, v in index.items() if k not in keys_to_skip})
         for dim, chunk_slice in chunk_array_slices.items():
             fragment_slice = fragment_slices[dim]
@@ -64,3 +66,48 @@ def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: C
             sub_fragment_index[dim_key] = DimVal(original_position, start, stop)
         sub_fragment_ds = ds.isel(**sub_fragment_indexer)
         yield tuple(sorted(target_chunk_group)), (sub_fragment_index, sub_fragment_ds)
+
+
+def _sort_index_key(item):
+    index = item[0]
+    return tuple(index.items())
+
+
+def combine_fragments(fragments: List[Tuple[Index, xr.Dataset]]) -> Tuple[Index, xr.Dataset]:
+    # we are combining over all the concat dims found in the indexes
+    # first check indexes for consistency
+    fragments.sort(key=_sort_index_key)  # this should sort by index
+    all_indexes = [item[0] for item in fragments]
+    first_index = all_indexes[0]
+    dim_keys = tuple(first_index)
+    if not all([tuple(index) == dim_keys for index in all_indexes]):
+        raise ValueError(
+            f"Cannot combine fragments for elements with different combine dims: {all_indexes}"
+        )
+    concat_dims = [dim_key for dim_key in dim_keys if dim_key.operation == CombineOp.CONCAT]
+    dim_names_and_vals = {
+        dim_key.name: [index[dim_key] for index in all_indexes] for dim_key in concat_dims
+    }
+    index_combined = Index()
+    for dim, dim_vals in dim_names_and_vals.items():
+        # check for contiguity
+        starts = [dim_val.start for dim_val in dim_vals][1:]
+        stops = [dim_val.stop for dim_val in dim_vals][:-1]
+        if not starts == stops:
+            raise ValueError(
+                f"Index starts and stops are not consistent for concat_dim {dim}: {dim_vals}"
+            )
+        # Position is unneeded at this point, but we still have to provide it
+        # This API probably needs to change
+        combined_dim_val = DimVal(dim_vals[0].position, dim_vals[0].start, dim_vals[-1].stop)
+        index_combined[DimKey(dim, CombineOp.CONCAT)] = combined_dim_val
+    # now create the nested dataset structure we need
+    shape = [len(dim_vals) for dim_vales in dim_names_and_vals.items()]
+    # some tricky workarounds to put xarray datasets into a nested list
+    all_datasets = np.empty(shape, dtype="O").ravel()
+    for n, fragment in enumerate(fragments):
+        all_datasets[n] = fragment[1]
+    dsets_to_concat = all_datasets.reshape(shape).tolist()
+    ds_combined = xr.combine_nested(dsets_to_concat, concat_dim=list(dim_names_and_vals))
+
+    return index_combined, ds_combined
diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
index 416b9f76..34d72a86 100644
--- a/tests/test_rechunking.py
+++ b/tests/test_rechunking.py
@@ -2,7 +2,7 @@
 import xarray as xr
 
 from pangeo_forge_recipes.patterns import CombineOp, DimKey, DimVal, Index
-from pangeo_forge_recipes.rechunking import split_fragment
+from pangeo_forge_recipes.rechunking import combine_fragments, split_fragment
 
 from .data_generation import make_ds
 
@@ -94,4 +94,24 @@
 
 
 def test_combine_fragments():
-    """Basically just the inverse of the ``test_split_multidim`` test."""
+    """The function applied after GroupBy to combine fragments into a single chunk.
+    All concat dims that appear more than once are combined.
+    """
+
+    nt = 4
+    time_chunk = 2
+    ds = make_ds(nt=nt)
+
+    fragments = []
+    dim_key = DimKey("time", CombineOp.CONCAT)
+    for nfrag, start in enumerate(range(0, nt, time_chunk)):
+        stop = min(start + time_chunk, nt)
+        # we are ignoring position (first item) at this point
+        index_frag = Index({dim_key: DimVal(0, start, stop)})
+        ds_frag = ds.isel(time=slice(start, stop))
+        fragments.append((index_frag, ds_frag))
+
+    index, ds_comb = combine_fragments(fragments)
+
+    assert index == Index({dim_key: DimVal(0, 0, nt)})
+    xr.testing.assert_equal(ds, ds_comb)

From f9a8e14be205001f7d38d08428358c7e591ba3c5 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sun, 14 Aug 2022 10:05:36 +0300
Subject: [PATCH 08/13] parametrized combine test

---
 pangeo_forge_recipes/rechunking.py | 20 +++++++++++++++++---
 tests/test_rechunking.py           |  6 +++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
index 5a6a4a39..b8370429 100644
--- a/pangeo_forge_recipes/rechunking.py
+++ b/pangeo_forge_recipes/rechunking.py
@@ -85,11 +85,16 @@ def combine_fragments(fragments: List[Tuple[Index, xr.Dataset]]) -> Tuple[Index,
             f"Cannot combine fragments for elements with different combine dims: {all_indexes}"
         )
     concat_dims = [dim_key for dim_key in dim_keys if dim_key.operation == CombineOp.CONCAT]
+    other_dims = [dim_key for dim_key in dim_keys if dim_key.operation != CombineOp.CONCAT]
+    # initialize new index with non-concat dims
+    index_combined = Index({dim: first_index[dim] for dim in other_dims})
     dim_names_and_vals = {
         dim_key.name: [index[dim_key] for index in all_indexes] for dim_key in concat_dims
     }
-    index_combined = Index()
     for dim, dim_vals in dim_names_and_vals.items():
+        for dim_val in dim_vals:
+            if dim_val.start is None or dim_val.stop is None:
+                raise ValueError("Can only combine indexed fragments.")
         # check for contiguity
         starts = [dim_val.start for dim_val in dim_vals][1:]
         stops = [dim_val.stop for dim_val in dim_vals][:-1]
@@ -102,12 +107,21 @@ def combine_fragments(fragments: List[Tuple[Index, xr.Dataset]]) -> Tuple[Index,
         combined_dim_val = DimVal(dim_vals[0].position, dim_vals[0].start, dim_vals[-1].stop)
         index_combined[DimKey(dim, CombineOp.CONCAT)] = combined_dim_val
     # now create the nested dataset structure we need
-    shape = [len(dim_vals) for dim_vales in dim_names_and_vals.items()]
+    shape = tuple(len(dim_vals) for dim_vals in dim_names_and_vals.values())
+    expected_dims = {
+        dim_name: (dim_vals[-1].stop - dim_vals[0].start)  # type: ignore
+        for dim_name, dim_vals in dim_names_and_vals.items()
+    }
     # some tricky workarounds to put xarray datasets into a nested list
     all_datasets = np.empty(shape, dtype="O").ravel()
     for n, fragment in enumerate(fragments):
         all_datasets[n] = fragment[1]
     dsets_to_concat = all_datasets.reshape(shape).tolist()
     ds_combined = xr.combine_nested(dsets_to_concat, concat_dim=list(dim_names_and_vals))
-
+    actual_dims = {dim: ds_combined.dims[dim] for dim in expected_dims}
+    if actual_dims != expected_dims:
+        raise ValueError(
+            f"Combined dataset dims {actual_dims} not the same as those expected "
+            f"from the index {expected_dims}"
+        )
     return index_combined, ds_combined
diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
index 34d72a86..3487772d 100644
--- a/tests/test_rechunking.py
+++ b/tests/test_rechunking.py
@@ -93,13 +93,13 @@ def test_split_multidim():
     )
 
 
-def test_combine_fragments():
+@pytest.mark.parametrize("time_chunk", [1, 2, 3, 5, 10])
+def test_combine_fragments(time_chunk):
     """The function applied after GroupBy to combine fragments into a single chunk.
     All concat dims that appear more than once are combined.
     """
 
-    nt = 4
-    time_chunk = 2
+    nt = 10
     ds = make_ds(nt=nt)
 
     fragments = []

From 957e3bd978f8073931673486eb4aaba4525dbf32 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sun, 21 Aug 2022 14:07:12 +0200
Subject: [PATCH 09/13] add a bunch of comments to rechunking.py

---
 pangeo_forge_recipes/rechunking.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
index b8370429..4895fbb6 100644
--- a/pangeo_forge_recipes/rechunking.py
+++ b/pangeo_forge_recipes/rechunking.py
@@ -16,10 +16,20 @@
 
 
 def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: ChunkDimDict):
+    """Split a single indexed dataset fragment into sub-fragments, according to the
+    specified target chunks
+
+    :param fragment: the indexed fragment. The index must have ``start`` and ``stop`` set.
+    :param target_chunks_and_dims: mapping from dimension name to a tuple of (chunksize, dimsize)
+    """
+
     index, ds = fragment
     chunk_grid = ChunkGrid.from_uniform_grid(target_chunks_and_dims)
 
+    # fragment_slices tells us where this fragment lies within the global dataset
     fragment_slices = {}  # type: Dict[str, slice]
+    # keys_to_skip is used to track dimensions that are present in both
+    # concat dims and target chunks
    keys_to_skip = []  # type: list[DimKey]
     for dim in target_chunks_and_dims:
         concat_dim_key = index.find_concat_dim(dim)
@@ -37,13 +47,16 @@ def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: C
 
     target_chunk_slices = chunk_grid.array_slice_to_chunk_slice(fragment_slices)
 
+    # each chunk we are going to yield is indexed by a "target chunk group",
+    # a tuple of tuples of the form (("lat", 1), ("time", 0))
     all_chunks = itertools.product(
         *(
             [(dim, n) for n in range(chunk_slice.start, chunk_slice.stop)]
             for dim, chunk_slice in target_chunk_slices.items()
         )
     )
 
+    # this iteration yields new fragments, indexed by their target chunk group
     for target_chunk_group in all_chunks:
         # now we need to figure out which piece of the fragment belongs in which chunk
         chunk_array_slices = chunk_grid.chunk_index_to_array_slice(dict(target_chunk_group))
@@ -74,6 +84,13 @@ def _sort_index_key(item):
 
 
 def combine_fragments(fragments: List[Tuple[Index, xr.Dataset]]) -> Tuple[Index, xr.Dataset]:
+    """Combine multiple dataset fragments into a single fragment.
+
+    Only combines concat dims; merge dims are not combined.
+
+    :param fragments: indexed dataset fragments
+    """
+
     # we are combining over all the concat dims found in the indexes
     # first check indexes for consistency
     fragments.sort(key=_sort_index_key)  # this should sort by index
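The contiguity check in combine_fragments pairs each fragment's start with its predecessor's stop. The same idea in isolation (a sketch with plain tuples, not the library API):

def check_contiguous(spans):
    # after sorting, starts[i] must equal stops[i] for adjacent fragments
    starts = [start for start, _ in spans][1:]
    stops = [stop for _, stop in spans][:-1]
    if starts != stops:
        raise ValueError(f"spans are not contiguous: {spans}")

check_contiguous([(0, 2), (2, 4), (4, 6)])  # passes
# check_contiguous([(0, 2), (3, 4)]) would raise: start 3 != previous stop 2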
From 0ee506950c5fc89d530c60e74694a2532393fcda Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Sun, 21 Aug 2022 19:35:36 +0200
Subject: [PATCH 10/13] starting to refactor index types

---
 pangeo_forge_recipes/patterns.py | 119 ++++++++-----------------------
 pangeo_forge_recipes/types.py    |  73 +++++++++++++++++++
 tests/test_patterns.py           |  32 ++++-----
 3 files changed, 117 insertions(+), 107 deletions(-)
 create mode 100644 pangeo_forge_recipes/types.py

diff --git a/pangeo_forge_recipes/patterns.py b/pangeo_forge_recipes/patterns.py
index 9a78eacd..046ea81a 100644
--- a/pangeo_forge_recipes/patterns.py
+++ b/pangeo_forge_recipes/patterns.py
@@ -9,18 +9,22 @@
 from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Sequence, Tuple, Union
 
 from .serialization import dict_drop_empty, dict_to_sha256
+from .types import CombineOp, Dimension, Index, IndexedPosition, Position
 
 
-class CombineOp(Enum):
-    """Used to uniquely identify different combine operations across Pangeo Forge Recipes."""
+@dataclass(frozen=True)
+class CombineDim:
+    name: str
+    operation: ClassVar[CombineOp]
+    keys: Sequence[Any] = field(repr=False)
 
-    MERGE = 1
-    CONCAT = 2
-    SUBSET = 3
+    @property
+    def dimension(self):
+        return Dimension(self.name, self.operation)
 
 
 @dataclass(frozen=True)
-class ConcatDim:
+class ConcatDim(CombineDim):
     """Represents a concatenation operation across a dimension of a FilePattern.
 
     :param name: The name of the dimension we are concatenating over. For
@@ -34,14 +38,12 @@ class ConcatDim:
     provide a fast path for recipes.
     """
 
-    name: str  # should match the actual dimension name
-    keys: Sequence[Any] = field(repr=False)
     nitems_per_file: Optional[int] = None
     operation: ClassVar[CombineOp] = CombineOp.CONCAT
 
 
 @dataclass(frozen=True)
-class MergeDim:
+class MergeDim(CombineDim):
     """Represents a merge operation across a dimension of a FilePattern.
 
     :param name: The name of the dimension we are are merging over. The actual
@@ -52,88 +54,20 @@ class MergeDim:
         the file name.
     """
 
-    name: str
-    keys: Sequence[Any] = field(repr=False)
     operation: ClassVar[CombineOp] = CombineOp.MERGE
 
 
-# would it be simpler to just use a tuple?
-@dataclass(frozen=True, order=True)
-class DimKey:
-    """
-    :param name: The name of the dimension we are combining over.
-    :param operation: What type of combination this is (merge or concat)
-    """
-
-    name: str
-    operation: CombineOp
-
-
-@dataclass(frozen=True, order=True)
-class DimVal:
-    """
-    :param position: Where this item lies within the sequence.
-    :param start: Where the starting array index for the item.
-    :param stop: The ending array index for the item.
-    """
-
-    position: int
-    start: Optional[int] = None
-    stop: Optional[int] = None
-
-
-# Alternative way of specifying type
-# Index = dict[DimKey, DimVal]
-
-
-class Index(Dict[DimKey, DimVal]):
-    """An Index is a special sort of dictionary which describes a position within
-    a multidimensional set.
-
-    - The key is a :class:`DimKey` which tells us which dimension we are addressing.
-    - The value is a :class:`DimVal` which tells us where we are within that dimension.
-
-    This object is hashable and deterministically serializable.
-    """
-
-    def __hash__(self):
-        return hash(tuple(self.__getstate__()))
-
-    def __getstate__(self):
-        return sorted(self.items())
-
-    def __setstate__(self, state):
-        self.__init__({k: v for k, v in state})
-
-    def find_concat_dim(self, dim_name: str) -> Optional[DimKey]:
-        possible_concat_dims = [
-            d for d in self if (d.name == dim_name and d.operation == CombineOp.CONCAT)
-        ]
-        if len(possible_concat_dims) > 1:
-            raise ValueError(
-                f"Found {len(possible_concat_dims)} concat dims named {dim_name} "
-                f"in the index {self}."
-            )
-        elif len(possible_concat_dims) == 0:
-            return None
-        else:
-            return possible_concat_dims[0]
-
-
-CombineDim = Union[MergeDim, ConcatDim]
-
-
-def augment_index_with_start_stop(dim_val: DimVal, item_lens: List[int]) -> DimVal:
+def augment_index_with_start_stop(position: Position, item_lens: List[int]) -> IndexedPosition:
     """Take an index _without_ start / stop and add them based on the lens defined in sequence_lens.
 
     :param index: The ``DimIndex`` instance to augment.
     :param item_lens: A list of integer lengths for all items in the sequence.
     """
-    start = sum(item_lens[: dim_val.position])
-    stop = start + item_lens[dim_val.position]
-
-    return DimVal(dim_val.position, start, stop)
+    if position.indexed:
+        raise ValueError("This position is already indexed")
+    start = sum(item_lens[: position.value])
+    return IndexedPosition(start)
 
 
 class AutoName(Enum):
@@ -238,23 +172,23 @@ def concat_sequence_lens(self) -> Dict[str, Optional[int]]:
         }
 
     @property
-    def combine_dim_keys(self) -> List[DimKey]:
-        return [DimKey(dim.name, dim.operation) for dim in self.combine_dims]
+    def combine_dim_keys(self) -> List[Dimension]:
+        return [Dimension(dim.name, dim.operation) for dim in self.combine_dims]
 
     def __getitem__(self, indexer: Index) -> str:
         """Get a filename path for a particular key."""
         assert len(indexer) == len(self.combine_dims)
         format_function_kwargs = {}
-        for dimkey, dimval in indexer.items():
+        for dimension, position in indexer.items():
             dims = [
-                dim
-                for dim in self.combine_dims
-                if dim.name == dimkey.name and dim.operation == dimkey.operation
+                combine_dim
+                for combine_dim in self.combine_dims
+                if combine_dim.dimension == dimension
             ]
             if len(dims) != 1:
-                raise KeyError(r"Could not valid combine_dim for indexer {idx_key}")
+                raise KeyError(f"Could not find a valid combine_dim for dimension {dimension}")
             dim = dims[0]
-            format_function_kwargs[dim.name] = dim.keys[dimval.position]
+            format_function_kwargs[dim.name] = dim.keys[position.value]
 
         fname = self.format_function(**format_function_kwargs)
         return fname
@@ -262,7 +196,10 @@
     def __iter__(self) -> Iterator[Index]:
         """Iterate over all keys in the pattern."""
         for val in product(*[range(n) for n in self.shape]):
             index = Index(
-                {DimKey(op.name, op.operation): DimVal(v) for op, v in zip(self.combine_dims, val)}
+                {
+                    Dimension(op.name, op.operation): Position(v)
+                    for op, v in zip(self.combine_dims, val)
+                }
             )
             yield index
diff --git a/pangeo_forge_recipes/types.py b/pangeo_forge_recipes/types.py
new file mode 100644
index 00000000..80a15235
--- /dev/null
+++ b/pangeo_forge_recipes/types.py
@@ -0,0 +1,73 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, Optional
+
+
+class CombineOp(Enum):
+    """Used to uniquely identify different combine operations across Pangeo Forge Recipes."""
+
+    MERGE = 1
+    CONCAT = 2
+    SUBSET = 3
+
+
+@dataclass(frozen=True, order=True)
+class Dimension:
+    """
+    :param name: The name of the dimension we are combining over.
+    :param operation: What type of combination this is (merge or concat)
+    """
+
+    name: str
+    operation: CombineOp
+
+
+@dataclass(frozen=True, order=True)
+class Position:
+    """
+    :param indexed: If True, this position represents an offset within a dataset
+    If False, it is a position within a sequence.
+    """
+
+    value: int
+    # TODO: consider using a ClassVar here
+    indexed: bool = False
+
+
+@dataclass(frozen=True, order=True)
+class IndexedPosition(Position):
+    indexed: bool = True
+
+
+class Index(Dict[Dimension, Position]):
+    """An Index is a special sort of dictionary which describes a position within
+    a multidimensional set.
+
+    - The key is a :class:`Dimension` which tells us which dimension we are addressing.
+    - The value is a :class:`Position` which tells us where we are within that dimension.
+
+    This object is hashable and deterministically serializable.
+    """
+
+    def __hash__(self):
+        return hash(tuple(self.__getstate__()))
+
+    def __getstate__(self):
+        return sorted(self.items())
+
+    def __setstate__(self, state):
+        self.__init__({k: v for k, v in state})
+
+    def find_concat_dim(self, dim_name: str) -> Optional[Dimension]:
+        possible_concat_dims = [
+            d for d in self if (d.name == dim_name and d.operation == CombineOp.CONCAT)
+        ]
+        if len(possible_concat_dims) > 1:
+            raise ValueError(
+                f"Found {len(possible_concat_dims)} concat dims named {dim_name} "
+                f"in the index {self}."
+            )
+        elif len(possible_concat_dims) == 0:
+            return None
+        else:
+            return possible_concat_dims[0]
diff --git a/tests/test_patterns.py b/tests/test_patterns.py
index 68f679de..ff2b828c 100644
--- a/tests/test_patterns.py
+++ b/tests/test_patterns.py
@@ -5,7 +5,6 @@
 from pangeo_forge_recipes.patterns import (
     CombineOp,
     ConcatDim,
-    DimVal,
     FilePattern,
     FileType,
     MergeDim,
@@ -13,6 +12,7 @@
     pattern_from_file_sequence,
     prune_pattern,
 )
+from pangeo_forge_recipes.types import IndexedPosition, Position
 
 
 @pytest.fixture
@@ -82,8 +82,8 @@ def test_pattern_from_file_sequence():
     assert fp.nitems_per_input == {"time": None}
     assert fp.concat_sequence_lens == {"time": None}
     for index in fp:
-        dimval = next(iter(index.values()))
-        assert fp[index] == file_sequence[dimval.position]
+        position = next(iter(index.values()))
+        assert fp[index] == file_sequence[position.value]
 
 
 @pytest.mark.parametrize("pickle", [False, True])
@@ -124,14 +124,14 @@ def test_file_pattern_concat_merge(runtime_secrets, pickle, concat_merge_pattern
     for index in fp:
         concat_dim_key = index.find_concat_dim("time")
         assert index.find_concat_dim("foobar") is None
-        for dimkey, dimval in index.items():
-            if dimkey.name == "time":
-                assert dimkey.operation == CombineOp.CONCAT
-                time_val = times[dimval.position]
-                assert dimval == index[concat_dim_key]
-            if dimkey.name == "variable":
-                assert dimkey.operation == CombineOp.MERGE
-                variable_val = varnames[dimval.position]
+        for dimension, position in index.items():
+            if dimension.name == "time":
+                assert dimension.operation == CombineOp.CONCAT
+                time_val = times[position.value]
+                assert position == index[concat_dim_key]
+            if dimension.name == "variable":
+                assert dimension.operation == CombineOp.MERGE
+                variable_val = varnames[position.value]
             expected_fname = format_function(time=time_val, variable=variable_val)
             assert fp[index] == expected_fname
@@ -203,11 +203,11 @@ def test_setting_file_types(file_type_value):
 
 
 @pytest.mark.parametrize(
-    "position,start,stop",
-    [(0, 0, 2), (1, 2, 4), (2, 4, 7), (3, 7, 9), (4, 9, 11)],
+    "position,start",
+    [(0, 0), (1, 2), (2, 4), (3, 7), (4, 9)],
 )
-def test_augment_index_with_start_stop(position, start, stop):
-    dk = DimVal(position)
-    expected = DimVal(position, start, stop)
+def test_augment_index_with_start_stop(position, start):
+    dk = Position(position)
+    expected = IndexedPosition(start)
     actual = augment_index_with_start_stop(dk, [2, 2, 3, 2, 2])
     assert actual == expected
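How the new types compose, assuming the package at this commit is importable (a usage sketch, not part of the patch): an Index maps each Dimension to a Position, and IndexedPosition marks values that are array offsets rather than sequence positions.

from pangeo_forge_recipes.types import (
    CombineOp, Dimension, Index, IndexedPosition, Position,
)

index = Index(
    {
        Dimension("time", CombineOp.CONCAT): IndexedPosition(10),
        Dimension("variable", CombineOp.MERGE): Position(0),
    }
)
assert index.find_concat_dim("time") == Dimension("time", CombineOp.CONCAT)
assert index[Dimension("time", CombineOp.CONCAT)].indexed  # an array offset
assert not index[Dimension("variable", CombineOp.MERGE)].indexed  # an ordinal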
test_augment_index_with_start_stop(position, start, stop): - dk = DimVal(position) - expected = DimVal(position, start, stop) +def test_augment_index_with_start_stop(position, start): + dk = Position(position) + expected = IndexedPosition(start) actual = augment_index_with_start_stop(dk, [2, 2, 3, 2, 2]) assert actual == expected From f035f602bdaf868c09b15d097d4a9270cd90950c Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Sun, 21 Aug 2022 19:45:57 +0200 Subject: [PATCH 11/13] fix combiners --- pangeo_forge_recipes/combiners.py | 10 +++++----- pangeo_forge_recipes/transforms.py | 18 +++++++++--------- tests/test_combiners.py | 25 ++++++++++++++----------- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/pangeo_forge_recipes/combiners.py b/pangeo_forge_recipes/combiners.py index 5538b9f8..0307bbcc 100644 --- a/pangeo_forge_recipes/combiners.py +++ b/pangeo_forge_recipes/combiners.py @@ -6,7 +6,7 @@ import apache_beam as beam from .aggregation import XarrayCombineAccumulator, XarraySchema -from .patterns import CombineOp, DimKey, Index +from .types import CombineOp, Dimension, Index @dataclass @@ -14,16 +14,16 @@ class CombineXarraySchemas(beam.CombineFn): """A beam ``CombineFn`` which we can use to combine multiple xarray schemas along a single dimension - :param dim_key: The dimension along which to combine + :param dimension: The dimension along which to combine """ - dim_key: DimKey + dimension: Dimension def get_position(self, index: Index) -> int: - return index[self.dim_key].position + return index[self.dimension].value def create_accumulator(self) -> XarrayCombineAccumulator: - concat_dim = self.dim_key.name if self.dim_key.operation == CombineOp.CONCAT else None + concat_dim = self.dimension.name if self.dimension.operation == CombineOp.CONCAT else None return XarrayCombineAccumulator(concat_dim=concat_dim) def add_input(self, accumulator: XarrayCombineAccumulator, item: Tuple[Index, XarraySchema]): diff --git a/pangeo_forge_recipes/transforms.py b/pangeo_forge_recipes/transforms.py index 6bc68594..519a0a18 100644 --- a/pangeo_forge_recipes/transforms.py +++ b/pangeo_forge_recipes/transforms.py @@ -11,7 +11,7 @@ from .aggregation import XarraySchema, dataset_to_schema, schema_to_zarr from .combiners import CombineXarraySchemas from .openers import open_url, open_with_xarray -from .patterns import CombineOp, DimKey, FileType, Index, augment_index_with_start_stop +from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_start_stop from .storage import CacheFSSpecTarget, FSSpecTarget from .writers import store_dataset_fragment @@ -124,10 +124,10 @@ def expand(self, pcoll): ) -def _nest_dim(item: Indexed[T], dim_key: DimKey) -> Indexed[Indexed[T]]: +def _nest_dim(item: Indexed[T], dimension: Dimension) -> Indexed[Indexed[T]]: index, value = item - inner_index = Index({dim_key: index[dim_key]}) - outer_index = Index({dk: index[dk] for dk in index if dk != dim_key}) + inner_index = Index({dimension: index[dimension]}) + outer_index = Index({dk: index[dk] for dk in index if dk != dimension}) return outer_index, (inner_index, value) @@ -136,13 +136,13 @@ class _NestDim(beam.PTransform): """Prepare a collection for grouping by transforming an Index into a nested Tuple of Indexes. 
- :param dim_key: The dimension to nest + :param dimension: The dimension to nest """ - dim_key: DimKey + dimension: Dimension def expand(self, pcoll): - return pcoll | beam.Map(_nest_dim, dim_key=self.dim_key) + return pcoll | beam.Map(_nest_dim, dimension=self.dimension) @dataclass @@ -159,7 +159,7 @@ class DetermineSchema(beam.PTransform): :param combine_dims: The dimensions to combine """ - combine_dims: List[DimKey] + combine_dims: List[Dimension] def expand(self, pcoll: beam.PCollection) -> beam.PCollection: cdims = self.combine_dims.copy() @@ -250,7 +250,7 @@ class StoreToZarr(beam.PTransform): # TODO: make it so we don't have to explictly specify combine_dims # Could be inferred from the pattern instead - combine_dims: List[DimKey] + combine_dims: List[Dimension] target_url: str target_chunks: Dict[str, int] = field(default_factory=dict) diff --git a/tests/test_combiners.py b/tests/test_combiners.py index 54de760b..68e8afc0 100644 --- a/tests/test_combiners.py +++ b/tests/test_combiners.py @@ -8,8 +8,9 @@ from pangeo_forge_recipes.aggregation import dataset_to_schema from pangeo_forge_recipes.combiners import CombineXarraySchemas -from pangeo_forge_recipes.patterns import CombineOp, DimKey, FilePattern, Index +from pangeo_forge_recipes.patterns import FilePattern from pangeo_forge_recipes.transforms import DetermineSchema, _NestDim +from pangeo_forge_recipes.types import CombineOp, Dimension, Index @pytest.fixture @@ -96,7 +97,7 @@ def test_CombineXarraySchemas_concat_1D(schema_pcoll_concat, pipeline): with pipeline as p: input = p | pcoll output = input | beam.CombineGlobally( - CombineXarraySchemas(DimKey(name=concat_dim, operation=CombineOp.CONCAT)) + CombineXarraySchemas(Dimension(name=concat_dim, operation=CombineOp.CONCAT)) ) assert_that(output, has_correct_schema(expected_schema)) @@ -132,12 +133,12 @@ def _check(actual): input = p | pcoll group1 = ( input - | "Nest CONCAT" >> _NestDim(DimKey("time", CombineOp.CONCAT)) + | "Nest CONCAT" >> _NestDim(Dimension("time", CombineOp.CONCAT)) | "Groupby CONCAT" >> beam.GroupByKey() ) group2 = ( input - | "Nest MERGE" >> _NestDim(DimKey("variable", CombineOp.MERGE)) + | "Nest MERGE" >> _NestDim(Dimension("variable", CombineOp.MERGE)) | "Groupy MERGE" >> beam.GroupByKey() ) assert_that(group1, check_key(merge_only_indexes, concat_only_indexes), label="merge") @@ -150,21 +151,23 @@ def test_DetermineSchema_concat_1D(schema_pcoll_concat, pipeline): with pipeline as p: input = p | pcoll - output = input | DetermineSchema([DimKey(name=concat_dim, operation=CombineOp.CONCAT)]) + output = input | DetermineSchema([Dimension(name=concat_dim, operation=CombineOp.CONCAT)]) assert_that(output, has_correct_schema(expected_schema), label="correct schema") -_dimkeys = [ - DimKey("time", operation=CombineOp.CONCAT), - DimKey("variable", operation=CombineOp.MERGE), +_dimensions = [ + Dimension("time", operation=CombineOp.CONCAT), + Dimension("variable", operation=CombineOp.MERGE), ] -@pytest.mark.parametrize("dimkeys", [_dimkeys, _dimkeys[::-1]], ids=["concat_first", "merge_first"]) -def test_DetermineSchema_concat_merge(dimkeys, schema_pcoll_concat_merge, pipeline): +@pytest.mark.parametrize( + "dimensions", [_dimensions, _dimensions[::-1]], ids=["concat_first", "merge_first"] +) +def test_DetermineSchema_concat_merge(dimensions, schema_pcoll_concat_merge, pipeline): pattern, expected_schema, pcoll = schema_pcoll_concat_merge with pipeline as p: input = p | pcoll - output = input | DetermineSchema(dimkeys) + output = input | 
DetermineSchema(dimensions) assert_that(output, has_correct_schema(expected_schema)) From ccbb967e25c003c726e774b8953abfb676018c76 Mon Sep 17 00:00:00 2001 From: Ryan Abernathey Date: Sun, 21 Aug 2022 19:59:15 +0200 Subject: [PATCH 12/13] fix writers --- pangeo_forge_recipes/writers.py | 17 +++++++++-------- tests/test_writers.py | 25 +++++++++++-------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/pangeo_forge_recipes/writers.py b/pangeo_forge_recipes/writers.py index 31aa74bd..08aa9fc0 100644 --- a/pangeo_forge_recipes/writers.py +++ b/pangeo_forge_recipes/writers.py @@ -10,13 +10,14 @@ def _region_for(var: xr.Variable, index: Index) -> Tuple[slice, ...]: region_slice = [] for dim, dimsize in var.sizes.items(): - concat_dim_key = index.find_concat_dim(dim) - if concat_dim_key: + concat_dimension = index.find_concat_dim(dim) + if concat_dimension: # we are concatenating over this dimension - concat_dim_val = index[concat_dim_key] - assert concat_dim_val.start is not None - assert concat_dim_val.stop == concat_dim_val.start + dimsize - region_slice.append(slice(concat_dim_val.start, concat_dim_val.stop)) + position = index[concat_dimension] + assert position.indexed + start = position.value + stop = start + dimsize + region_slice.append(slice(start, stop)) else: # we are writing the entire dimension region_slice.append(slice(None)) @@ -37,7 +38,7 @@ def _store_data(vname: str, var: xr.Variable, index: Index, zgroup: zarr.Group) def _is_first_item(index): for _, v in index.items(): - if v.position > 0: + if v.value > 0: return False return True @@ -45,7 +46,7 @@ def _is_first_item(index): def _is_first_in_merge_dim(index): for k, v in index.items(): if k.operation == CombineOp.MERGE: - if v.position > 0: + if v.value > 0: return False return True diff --git a/tests/test_writers.py b/tests/test_writers.py index 333da08f..26ddbbe0 100644 --- a/tests/test_writers.py +++ b/tests/test_writers.py @@ -3,7 +3,7 @@ import zarr from pangeo_forge_recipes.aggregation import schema_to_zarr -from pangeo_forge_recipes.patterns import CombineOp, DimKey, DimVal, Index +from pangeo_forge_recipes.types import CombineOp, Dimension, Index, IndexedPosition, Position from pangeo_forge_recipes.writers import store_dataset_fragment from .data_generation import make_ds @@ -48,8 +48,8 @@ def test_store_dataset_fragment(temp_store): fragment_1_1 = ds[["bar"]].isel(time=slice(2, 4)) index_1_1 = Index( { - DimKey("time", CombineOp.CONCAT): DimVal(position=1, start=2, stop=4), - DimKey("variable", CombineOp.MERGE): DimVal(position=1), + Dimension("time", CombineOp.CONCAT): IndexedPosition(2), + Dimension("variable", CombineOp.MERGE): Position(1), } ) @@ -70,8 +70,8 @@ def test_store_dataset_fragment(temp_store): fragment_0_1 = ds[["foo"]].isel(time=slice(2, 4)) index_0_1 = Index( { - DimKey("time", CombineOp.CONCAT): DimVal(position=1, start=2, stop=4), - DimKey("variable", CombineOp.MERGE): DimVal(position=0), + Dimension("time", CombineOp.CONCAT): IndexedPosition(2), + Dimension("variable", CombineOp.MERGE): Position(0), } ) @@ -93,8 +93,8 @@ def test_store_dataset_fragment(temp_store): fragment_0_0 = ds[["foo"]].isel(time=slice(0, 2)) index_0_0 = Index( { - DimKey("time", CombineOp.CONCAT): DimVal(position=0, start=0, stop=2), - DimKey("variable", CombineOp.MERGE): DimVal(position=0), + Dimension("time", CombineOp.CONCAT): IndexedPosition(0), + Dimension("variable", CombineOp.MERGE): Position(0), } ) @@ -114,8 +114,8 @@ def test_store_dataset_fragment(temp_store): fragment_1_0 = 
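What _nest_dim does to each element, in isolation (a sketch mirroring the function in the patch above, using the post-refactor types):

from pangeo_forge_recipes.types import CombineOp, Dimension, Index, Position

def nest_dim(index, dimension):
    # pull one Dimension out as the inner (grouping) key; keep the rest outer
    inner = Index({dimension: index[dimension]})
    outer = Index({k: v for k, v in index.items() if k != dimension})
    return outer, inner

time = Dimension("time", CombineOp.CONCAT)
variable = Dimension("variable", CombineOp.MERGE)
outer, inner = nest_dim(Index({time: Position(3), variable: Position(1)}), time)
assert inner == Index({time: Position(3)})
assert outer == Index({variable: Position(1)})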
From a3765363647964d6dce5abd6cb013b7436afe5a3 Mon Sep 17 00:00:00 2001
From: Ryan Abernathey
Date: Mon, 29 Aug 2022 13:13:38 -0400
Subject: [PATCH 13/13] temporarily remove rechunking stuff to make review
 easier

---
 pangeo_forge_recipes/rechunking.py | 147 -----------------------------
 tests/test_rechunking.py           | 117 -----------------------
 2 files changed, 264 deletions(-)
 delete mode 100644 pangeo_forge_recipes/rechunking.py
 delete mode 100644 tests/test_rechunking.py

diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
deleted file mode 100644
index 4895fbb6..00000000
--- a/pangeo_forge_recipes/rechunking.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import itertools
-from typing import Dict, List, Tuple
-
-import numpy as np
-import xarray as xr
-
-from .chunk_grid import ChunkGrid
-from .patterns import CombineOp, DimKey, DimVal, Index
-
-ChunkDimDict = Dict[str, Tuple[int, int]]
-
-# group keys are a tuple of tuples like (("lon", 1), ("time", 0))
-# the ints are chunk indexes
-# code should always sort the key before emitting it
-GroupKey = Tuple[Tuple[str, int], ...]
-
-
-def split_fragment(fragment: Tuple[Index, xr.Dataset], target_chunks_and_dims: ChunkDimDict):
-    """Split a single indexed dataset fragment into sub-fragments, according to the
-    specified target chunks
-
-    :param fragment: the indexed fragment. The index must have ``start`` and ``stop`` set.
-    :param target_chunks_and_dims: mapping from dimension name to a tuple of (chunksize, dimsize)
-    """
-
-    index, ds = fragment
-    chunk_grid = ChunkGrid.from_uniform_grid(target_chunks_and_dims)
-
-    # fragment_slices tells us where this fragment lies within the global dataset
-    fragment_slices = {}  # type: Dict[str, slice]
-    # keys_to_skip is used to track dimensions that are present in both
-    # concat dims and target chunks
-    keys_to_skip = []  # type: list[DimKey]
-    for dim in target_chunks_and_dims:
-        concat_dim_key = index.find_concat_dim(dim)
-        if concat_dim_key:
-            # this dimension is present in the fragment as a concat dim
-            concat_dim_val = index[concat_dim_key]
-            dim_slice = slice(concat_dim_val.start, concat_dim_val.stop)
-            keys_to_skip.append(concat_dim_key)
-        else:
-            # If there is a target_chunk that is NOT present as a concat_dim in the fragment,
-            # then we can assume that the entire span of that dimension is present in the dataset
-            # This would arise e.g. when decimating a contiguous dimension
-            dim_slice = slice(0, ds.dims[dim])
-        fragment_slices[dim] = dim_slice
-
-    target_chunk_slices = chunk_grid.array_slice_to_chunk_slice(fragment_slices)
-
-    # each chunk we are going to yield is indexed by a "target chunk group",
-    # a tuple of tuples of the form (("lat", 1), ("time", 0))
-    all_chunks = itertools.product(
-        *(
-            [(dim, n) for n in range(chunk_slice.start, chunk_slice.stop)]
-            for dim, chunk_slice in target_chunk_slices.items()
-        )
-    )
-
-    # this iteration yields new fragments, indexed by their target chunk group
-    for target_chunk_group in all_chunks:
-        # now we need to figure out which piece of the fragment belongs in which chunk
-        chunk_array_slices = chunk_grid.chunk_index_to_array_slice(dict(target_chunk_group))
-        sub_fragment_indexer = {}  # passed to ds.isel
-        # initialize the new index with the items we want to keep from the original index
-        # TODO: think about whether we want to always rechunk concat dims
-        sub_fragment_index = Index({k: v for k, v in index.items() if k not in keys_to_skip})
-        for dim, chunk_slice in chunk_array_slices.items():
-            fragment_slice = fragment_slices[dim]
-            start = max(chunk_slice.start, fragment_slice.start)
-            stop = min(chunk_slice.stop, fragment_slice.stop)
-            sub_fragment_indexer[dim] = slice(
-                start - fragment_slice.start, stop - fragment_slice.start
-            )
-            dim_key = DimKey(dim, CombineOp.CONCAT)
-            # I am getting the original "position" value from the original index
-            # Not sure if this makes sense. There is no way to know the actual position here
-            # without knowing all the previous subfragments
-            original_position = getattr(index.get(dim_key), "position", 0)
-            sub_fragment_index[dim_key] = DimVal(original_position, start, stop)
-        sub_fragment_ds = ds.isel(**sub_fragment_indexer)
-        yield tuple(sorted(target_chunk_group)), (sub_fragment_index, sub_fragment_ds)
-
-
-def _sort_index_key(item):
-    index = item[0]
-    return tuple(index.items())
-
-
-def combine_fragments(fragments: List[Tuple[Index, xr.Dataset]]) -> Tuple[Index, xr.Dataset]:
-    """Combine multiple dataset fragments into a single fragment.
-
-    Only combines concat dims; merge dims are not combined.
-
-    :param fragments: indexed dataset fragments
-    """
-
-    # we are combining over all the concat dims found in the indexes
-    # first check indexes for consistency
-    fragments.sort(key=_sort_index_key)  # this should sort by index
-    all_indexes = [item[0] for item in fragments]
-    first_index = all_indexes[0]
-    dim_keys = tuple(first_index)
-    if not all([tuple(index) == dim_keys for index in all_indexes]):
-        raise ValueError(
-            f"Cannot combine fragments for elements with different combine dims: {all_indexes}"
-        )
-    concat_dims = [dim_key for dim_key in dim_keys if dim_key.operation == CombineOp.CONCAT]
-    other_dims = [dim_key for dim_key in dim_keys if dim_key.operation != CombineOp.CONCAT]
-    # initialize new index with non-concat dims
-    index_combined = Index({dim: first_index[dim] for dim in other_dims})
-    dim_names_and_vals = {
-        dim_key.name: [index[dim_key] for index in all_indexes] for dim_key in concat_dims
-    }
-    for dim, dim_vals in dim_names_and_vals.items():
-        for dim_val in dim_vals:
-            if dim_val.start is None or dim_val.stop is None:
-                raise ValueError("Can only combine indexed fragments.")
-        # check for contiguity
-        starts = [dim_val.start for dim_val in dim_vals][1:]
-        stops = [dim_val.stop for dim_val in dim_vals][:-1]
-        if not starts == stops:
-            raise ValueError(
-                f"Index starts and stops are not consistent for concat_dim {dim}: {dim_vals}"
-            )
-        # Position is unneeded at this point, but we still have to provide it
-        # This API probably needs to change
-        combined_dim_val = DimVal(dim_vals[0].position, dim_vals[0].start, dim_vals[-1].stop)
-        index_combined[DimKey(dim, CombineOp.CONCAT)] = combined_dim_val
-    # now create the nested dataset structure we need
-    shape = tuple(len(dim_vals) for dim_vals in dim_names_and_vals.values())
-    expected_dims = {
-        dim_name: (dim_vals[-1].stop - dim_vals[0].start)  # type: ignore
-        for dim_name, dim_vals in dim_names_and_vals.items()
-    }
-    # some tricky workarounds to put xarray datasets into a nested list
-    all_datasets = np.empty(shape, dtype="O").ravel()
-    for n, fragment in enumerate(fragments):
-        all_datasets[n] = fragment[1]
-    dsets_to_concat = all_datasets.reshape(shape).tolist()
-    ds_combined = xr.combine_nested(dsets_to_concat, concat_dim=list(dim_names_and_vals))
-    actual_dims = {dim: ds_combined.dims[dim] for dim in expected_dims}
-    if actual_dims != expected_dims:
-        raise ValueError(
-            f"Combined dataset dims {actual_dims} not the same as those expected "
-            f"from the index {expected_dims}"
-        )
-    return index_combined, ds_combined
diff --git a/tests/test_rechunking.py b/tests/test_rechunking.py
deleted file mode 100644
index 3487772d..00000000
--- a/tests/test_rechunking.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import pytest
-import xarray as xr
-
-from pangeo_forge_recipes.patterns import CombineOp, DimKey, DimVal, Index
-from pangeo_forge_recipes.rechunking import combine_fragments, split_fragment
-
-from .data_generation import make_ds
-
-
-@pytest.mark.parametrize("offset", [0, 5])  # hypothetical offset of this fragment
-@pytest.mark.parametrize("time_chunks", [1, 3, 5, 10, 11])
-def test_split_fragment(time_chunks, offset):
-    """A thorough test of 1D splitting logic that should cover all major edge cases."""
-
-    nt_total = 20  # the total size of the hypothetical dataset
-    target_chunks_and_dims = {"time": (time_chunks, nt_total)}
-
-    nt = 10
-    ds = make_ds(nt=nt)  # this represents a single dataset fragment
-    dim_key = DimKey("time", CombineOp.CONCAT)
-
-    extra_indexes = [
-        (DimKey("foo", CombineOp.CONCAT), DimVal(0)),
-        (DimKey("bar", CombineOp.MERGE), DimVal(1)),
-    ]
-
-    index = Index([(dim_key, DimVal(0, offset, offset + nt))] + extra_indexes)
-
-    all_splits = list(split_fragment((index, ds), target_chunks_and_dims=target_chunks_and_dims))
-
-    group_keys = [item[0] for item in all_splits]
-    new_indexes = [item[1][0] for item in all_splits]
-    new_datasets = [item[1][1] for item in all_splits]
-
-    for n in range(len(all_splits)):
-        chunk_number = offset // time_chunks + n
-        assert group_keys[n] == (("time", chunk_number),)
-        chunk_start = time_chunks * chunk_number
-        chunk_stop = min(time_chunks * (chunk_number + 1), nt_total)
-        fragment_start = max(chunk_start, offset)
-        fragment_stop = min(chunk_stop, fragment_start + time_chunks, offset + nt)
-        # other dimensions in the index should be passed through unchanged
-        assert new_indexes[n] == Index(
-            [(dim_key, DimVal(0, fragment_start, fragment_stop))] + extra_indexes
-        )
-        start, stop = fragment_start - offset, fragment_stop - offset
-        xr.testing.assert_equal(new_datasets[n], ds.isel(time=slice(start, stop)))
-
-    # make sure we got the whole dataset back
-    ds_concat = xr.concat(new_datasets, "time")
-    xr.testing.assert_equal(ds, ds_concat)
-
-
-def test_split_multidim():
-    """A simple test that checks whether splitting logic is applied correctly
-    for multiple dimensions."""
-
-    nt = 2
-    ds = make_ds(nt=nt)
-    nlat = ds.dims["lat"]
-    dim_key = DimKey("time", CombineOp.CONCAT)
-    index = Index({dim_key: DimVal(0, 0, nt)})
-
-    time_chunks = 1
-    lat_chunks = nlat // 2
-    target_chunks_and_dims = {"time": (time_chunks, nt), "lat": (lat_chunks, nlat)}
-
-    all_splits = list(split_fragment((index, ds), target_chunks_and_dims=target_chunks_and_dims))
-
-    group_keys = [item[0] for item in all_splits]
-
-    assert group_keys == [
-        (("lat", 0), ("time", 0)),
-        (("lat", 1), ("time", 0)),
-        (("lat", 0), ("time", 1)),
-        (("lat", 1), ("time", 1)),
-    ]
-
-    for group_key, (fragment_index, fragment_ds) in all_splits:
-        n_lat_chunk = group_key[0][1]
-        n_time_chunk = group_key[1][1]
-        time_start, time_stop = n_time_chunk * time_chunks, (n_time_chunk + 1) * time_chunks
-        lat_start, lat_stop = n_lat_chunk * lat_chunks, (n_lat_chunk + 1) * lat_chunks
-        expected_index = Index(
-            {
-                DimKey("time", CombineOp.CONCAT): DimVal(0, time_start, time_stop),
-                DimKey("lat", CombineOp.CONCAT): DimVal(0, lat_start, lat_stop),
-            }
-        )
-        assert fragment_index == expected_index
-        xr.testing.assert_equal(
-            fragment_ds, ds.isel(time=slice(time_start, time_stop), lat=slice(lat_start, lat_stop))
-        )
-
-
-@pytest.mark.parametrize("time_chunk", [1, 2, 3, 5, 10])
-def test_combine_fragments(time_chunk):
-    """The function applied after GroupBy to combine fragments into a single chunk.
-    All concat dims that appear more than once are combined.
-    """
-
-    nt = 10
-    ds = make_ds(nt=nt)
-
-    fragments = []
-    dim_key = DimKey("time", CombineOp.CONCAT)
-    for nfrag, start in enumerate(range(0, nt, time_chunk)):
-        stop = min(start + time_chunk, nt)
-        # we are ignoring position (first item) at this point
-        index_frag = Index({dim_key: DimVal(0, start, stop)})
-        ds_frag = ds.isel(time=slice(start, stop))
-        fragments.append((index_frag, ds_frag))
-
-    index, ds_comb = combine_fragments(fragments)
-
-    assert index == Index({dim_key: DimVal(0, 0, nt)})
-    xr.testing.assert_equal(ds, ds_comb)