From 5d7e9a78d0de8a14d4ae95bc2967f5fd8b5c5ed6 Mon Sep 17 00:00:00 2001 From: Thomas Zilio Date: Tue, 19 Nov 2024 15:34:47 +0100 Subject: [PATCH] feat: Adding the variables parameter to the View.update function. Refs: #13 --- zcollection/collection/__init__.py | 2 +- zcollection/merging/tests/test_merging.py | 13 +++-- zcollection/view/__init__.py | 20 ++++--- zcollection/view/tests/test_view.py | 69 ++++++++++++++++++++++- 4 files changed, 88 insertions(+), 16 deletions(-) diff --git a/zcollection/collection/__init__.py b/zcollection/collection/__init__.py index 8c09807..b78bb79 100644 --- a/zcollection/collection/__init__.py +++ b/zcollection/collection/__init__.py @@ -517,7 +517,7 @@ def update( the variables are inferred by calling the function on the first partition. In this case, it is important to ensure that the function can be called twice on the same partition without - side-effects. Default is None. + side effects. Default is None. **kwargs: The keyword arguments to pass to the function. Raises: diff --git a/zcollection/merging/tests/test_merging.py b/zcollection/merging/tests/test_merging.py index a8dc080..1e535c9 100644 --- a/zcollection/merging/tests/test_merging.py +++ b/zcollection/merging/tests/test_merging.py @@ -45,12 +45,13 @@ def test_update_fs( """Test the _update_fs function.""" generator = data.create_test_dataset(delayed=False) zds = next(generator) + zds_sc = dask_client.scatter(zds) partition_folder = local_fs.root.joinpath('variable=1') zattrs = str(partition_folder.joinpath('.zattrs')) - future = dask_client.submit(_update_fs, str(partition_folder), - dask_client.scatter(zds), local_fs.fs) + future = dask_client.submit(_update_fs, str(partition_folder), zds_sc, + local_fs.fs) dask_client.gather(future) assert local_fs.exists(zattrs) @@ -60,7 +61,7 @@ def test_update_fs( try: future = dask_client.submit(_update_fs, str(partition_folder), - dask_client.scatter(zds), + zds_sc, local_fs.fs, synchronizer=ThrowError()) dask_client.gather(future) @@ -83,13 +84,13 @@ def test_perform( zds = next(generator) path = str(local_fs.root.joinpath('variable=1')) + zds_sc = dask_client.scatter(zds) - future = dask_client.submit(_update_fs, path, dask_client.scatter(zds), - local_fs.fs) + future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs) dask_client.gather(future) future = dask_client.submit(perform, - dask_client.scatter(zds), + zds_sc, path, 'time', local_fs.fs, diff --git a/zcollection/view/__init__.py b/zcollection/view/__init__.py index 097b56d..a99a8ea 100644 --- a/zcollection/view/__init__.py +++ b/zcollection/view/__init__.py @@ -26,7 +26,6 @@ from ..collection.callable_objects import MapCallable, PartitionCallable from ..collection.detail import _try_infer_callable from ..convenience import collection as convenience -from ..type_hints import ArrayLike from .detail import ( ViewReference, ViewUpdateCallable, @@ -418,6 +417,7 @@ def update( npartitions: int | None = None, selected_variables: Iterable[str] | None = None, trim: bool = True, + variables: Sequence[str] | None = None, **kwargs, ) -> None: """Update a variable stored int the view. @@ -446,6 +446,11 @@ def update( trim: Whether to trim ``depth`` items from each partition after calling ``func``. Set it to ``False`` if your function does this for you. + variables: The list of variables updated by the function. If None, + the variables are inferred by calling the function on the first + partition. In this case, it is important to ensure that the + function can be called twice on the same partition without + side effects. Default is None. args: The positional arguments to pass to the function. kwargs: The keyword arguments to pass to the function. @@ -485,16 +490,17 @@ def update( 'data is selected with the given filters.') return - func_result: dict[str, ArrayLike] = _try_infer_callable( - func, datasets_list[0][0], self.view_ref.partition_properties.dim, - *args, **kwargs) + variables = variables or tuple( + _try_infer_callable(func, datasets_list[0][0], + self.view_ref.partition_properties.dim, *args, + **kwargs)) tuple( map( lambda varname: _assert_variable_handled( self.view_ref.metadata, self.metadata, varname), - func_result)) + variables)) _LOGGER.info('Updating variable %s', - ', '.join(repr(item) for item in func_result)) + ', '.join(repr(item) for item in variables)) # Function to apply to each partition. wrap_function: ViewUpdateCallable @@ -509,7 +515,7 @@ def update( ) else: if selected_variables is not None and len( - set(func_result) & set(selected_variables)) == 0: + set(variables) & set(selected_variables)) == 0: raise ValueError( 'If the depth is greater than 0, the selected variables ' 'must contain the variables updated by the function.') diff --git a/zcollection/view/tests/test_view.py b/zcollection/view/tests/test_view.py index 365249c..3182fd2 100644 --- a/zcollection/view/tests/test_view.py +++ b/zcollection/view/tests/test_view.py @@ -8,12 +8,13 @@ """ from __future__ import annotations +import logging import pathlib import numpy import pytest -from ... import collection, convenience, meta, partitioning, view +from ... import collection, convenience, dataset, meta, partitioning, view # pylint: disable=unused-import # Need to import for fixtures from ...tests.cluster import dask_client, dask_cluster from ...tests.data import ( @@ -23,6 +24,7 @@ ) from ...tests.fixture import dask_arrays, numpy_arrays from ...tests.fs import local_fs, s3, s3_base, s3_fs +from ...type_hints import ArrayLike from ...view.detail import _calculate_axis_reference # pylint: enable=unused-import @@ -136,7 +138,7 @@ def update(zds, varname): zds = instance.load(delayed=delayed) assert zds is not None - numpy.all(zds.variables['var3'].values == 5) + assert numpy.all(zds.variables['var3'].values == 5) indexers = instance.map( lambda x: slice(0, x.dimensions['num_lines']) # type: ignore @@ -161,6 +163,69 @@ def update(zds, varname): filesystem=tested_fs.fs) +@pytest.mark.parametrize('fs', ['local_fs', 's3_fs']) +def test_view_update( + dask_client, # pylint: disable=redefined-outer-name,unused-argument + fs, + request, + caplog): + """Test the creation of a view.""" + tested_fs = request.getfixturevalue(fs) + + create_test_collection(tested_fs, delayed=False) + instance = convenience.create_view(path=str(tested_fs.view), + view_ref=view.ViewReference( + str(tested_fs.collection), + tested_fs.fs), + filesystem=tested_fs.fs) + + var_name = 'var3' + log_msg = 'Update called' + + var = meta.Variable(name=var_name, + dtype=numpy.float64, + dimensions=('num_lines', 'num_pixels')) + + instance.add_variable(var) + + def to_zero(zds: dataset.Dataset, varname): + """Update function used to set a variable to 0.""" + logging.info(log_msg) + return {varname: zds.variables['var1'].values * 0} + + instance.update(to_zero, var_name) # type: ignore + + data = instance.load(delayed=False) + assert numpy.all(data.variables[var_name].values == 0) + + def plus_one_with_log(zds: dataset.Dataset, varname): + """Update function increasing a variable by 1.""" + logging.info(log_msg) + return {varname: zds.variables[var_name].values + 1} + + caplog.set_level(logging.INFO) + caplog.clear() + + instance.update(plus_one_with_log, var_name) # type: ignore + + # One log per partition + 1 log for the initial call + assert caplog.text.count(log_msg) == len(list(instance.partitions())) + 1 + + data = instance.load(delayed=False) + assert numpy.all(data.variables[var_name].values == 1) + + caplog.clear() + instance.update( + plus_one_with_log, # type: ignore + var_name, + variables=[var_name]) + + assert caplog.text.count(log_msg) == len(list(instance.partitions())) + + data = instance.load(delayed=False) + assert numpy.all(data.variables[var_name].values == 2) + + @pytest.mark.parametrize('arg', ['local_fs', 's3_fs']) def test_view_overlap( dask_client, # pylint: disable=redefined-outer-name,unused-argument