Skip to content

Commit

Permalink
feat: Adding the variables parameter to the View.update function.
Browse files Browse the repository at this point in the history
Refs: #13
  • Loading branch information
Thomas Zilio committed Nov 19, 2024
1 parent 7faa803 commit 5d7e9a7
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 16 deletions.
2 changes: 1 addition & 1 deletion zcollection/collection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def update(
the variables are inferred by calling the function on the first
partition. In this case, it is important to ensure that the
function can be called twice on the same partition without
side-effects. Default is None.
side effects. Default is None.
**kwargs: The keyword arguments to pass to the function.
Raises:
Expand Down
13 changes: 7 additions & 6 deletions zcollection/merging/tests/test_merging.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def test_update_fs(
"""Test the _update_fs function."""
generator = data.create_test_dataset(delayed=False)
zds = next(generator)
zds_sc = dask_client.scatter(zds)

partition_folder = local_fs.root.joinpath('variable=1')

zattrs = str(partition_folder.joinpath('.zattrs'))
future = dask_client.submit(_update_fs, str(partition_folder),
dask_client.scatter(zds), local_fs.fs)
future = dask_client.submit(_update_fs, str(partition_folder), zds_sc,
local_fs.fs)
dask_client.gather(future)
assert local_fs.exists(zattrs)

Expand All @@ -60,7 +61,7 @@ def test_update_fs(
try:
future = dask_client.submit(_update_fs,
str(partition_folder),
dask_client.scatter(zds),
zds_sc,
local_fs.fs,
synchronizer=ThrowError())
dask_client.gather(future)
Expand All @@ -83,13 +84,13 @@ def test_perform(
zds = next(generator)

path = str(local_fs.root.joinpath('variable=1'))
zds_sc = dask_client.scatter(zds)

future = dask_client.submit(_update_fs, path, dask_client.scatter(zds),
local_fs.fs)
future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs)
dask_client.gather(future)

future = dask_client.submit(perform,
dask_client.scatter(zds),
zds_sc,
path,
'time',
local_fs.fs,
Expand Down
20 changes: 13 additions & 7 deletions zcollection/view/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from ..collection.callable_objects import MapCallable, PartitionCallable
from ..collection.detail import _try_infer_callable
from ..convenience import collection as convenience
from ..type_hints import ArrayLike
from .detail import (
ViewReference,
ViewUpdateCallable,
Expand Down Expand Up @@ -418,6 +417,7 @@ def update(
npartitions: int | None = None,
selected_variables: Iterable[str] | None = None,
trim: bool = True,
variables: Sequence[str] | None = None,
**kwargs,
) -> None:
"""Update a variable stored in the view.
Expand Down Expand Up @@ -446,6 +446,11 @@ def update(
trim: Whether to trim ``depth`` items from each partition after
calling ``func``. Set it to ``False`` if your function does
this for you.
variables: The list of variables updated by the function. If None,
the variables are inferred by calling the function on the first
partition. In this case, it is important to ensure that the
function can be called twice on the same partition without
side effects. Default is None.
args: The positional arguments to pass to the function.
kwargs: The keyword arguments to pass to the function.
Expand Down Expand Up @@ -485,16 +490,17 @@ def update(
'data is selected with the given filters.')
return

func_result: dict[str, ArrayLike] = _try_infer_callable(
func, datasets_list[0][0], self.view_ref.partition_properties.dim,
*args, **kwargs)
variables = variables or tuple(
_try_infer_callable(func, datasets_list[0][0],
self.view_ref.partition_properties.dim, *args,
**kwargs))
tuple(
map(
lambda varname: _assert_variable_handled(
self.view_ref.metadata, self.metadata, varname),
func_result))
variables))
_LOGGER.info('Updating variable %s',
', '.join(repr(item) for item in func_result))
', '.join(repr(item) for item in variables))

# Function to apply to each partition.
wrap_function: ViewUpdateCallable
Expand All @@ -509,7 +515,7 @@ def update(
)
else:
if selected_variables is not None and len(
set(func_result) & set(selected_variables)) == 0:
set(variables) & set(selected_variables)) == 0:
raise ValueError(
'If the depth is greater than 0, the selected variables '
'must contain the variables updated by the function.')
Expand Down
69 changes: 67 additions & 2 deletions zcollection/view/tests/test_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
"""
from __future__ import annotations

import logging
import pathlib

import numpy
import pytest

from ... import collection, convenience, meta, partitioning, view
from ... import collection, convenience, dataset, meta, partitioning, view
# pylint: disable=unused-import # Need to import for fixtures
from ...tests.cluster import dask_client, dask_cluster
from ...tests.data import (
Expand All @@ -23,6 +24,7 @@
)
from ...tests.fixture import dask_arrays, numpy_arrays
from ...tests.fs import local_fs, s3, s3_base, s3_fs
from ...type_hints import ArrayLike
from ...view.detail import _calculate_axis_reference

# pylint: enable=unused-import
Expand Down Expand Up @@ -136,7 +138,7 @@ def update(zds, varname):

zds = instance.load(delayed=delayed)
assert zds is not None
numpy.all(zds.variables['var3'].values == 5)
assert numpy.all(zds.variables['var3'].values == 5)

indexers = instance.map(
lambda x: slice(0, x.dimensions['num_lines']) # type: ignore
Expand All @@ -161,6 +163,69 @@ def update(zds, varname):
filesystem=tested_fs.fs)


@pytest.mark.parametrize('fs', ['local_fs', 's3_fs'])
def test_view_update(
        dask_client,  # pylint: disable=redefined-outer-name,unused-argument
        fs,
        request,
        caplog):
    """Test ``View.update``, with and without the explicit ``variables``
    parameter naming the variables written by the update function."""
    tested_fs = request.getfixturevalue(fs)

    # Build a reference collection and open a view on top of it.
    create_test_collection(tested_fs, delayed=False)
    instance = convenience.create_view(
        path=str(tested_fs.view),
        view_ref=view.ViewReference(str(tested_fs.collection),
                                    tested_fs.fs),
        filesystem=tested_fs.fs)

    var_name = 'var3'
    log_msg = 'Update called'

    # Declare the variable that the update functions below will fill.
    instance.add_variable(
        meta.Variable(name=var_name,
                      dtype=numpy.float64,
                      dimensions=('num_lines', 'num_pixels')))

    def zero_fill(zds: dataset.Dataset, varname):
        """Set the target variable to zero everywhere."""
        logging.info(log_msg)
        return {varname: zds.variables['var1'].values * 0}

    instance.update(zero_fill, var_name)  # type: ignore

    loaded = instance.load(delayed=False)
    assert numpy.all(loaded.variables[var_name].values == 0)

    def increment(zds: dataset.Dataset, varname):
        """Add one to the target variable."""
        logging.info(log_msg)
        return {varname: zds.variables[var_name].values + 1}

    caplog.set_level(logging.INFO)
    caplog.clear()

    instance.update(increment, var_name)  # type: ignore

    # Without ``variables`` the updated variables are inferred by one extra
    # call to the function, so: one call per partition + one inference call.
    assert caplog.text.count(log_msg) == len(list(instance.partitions())) + 1

    loaded = instance.load(delayed=False)
    assert numpy.all(loaded.variables[var_name].values == 1)

    caplog.clear()
    instance.update(
        increment,  # type: ignore
        var_name,
        variables=[var_name])

    # With ``variables`` given explicitly, the inference call is skipped:
    # exactly one call per partition.
    assert caplog.text.count(log_msg) == len(list(instance.partitions()))

    loaded = instance.load(delayed=False)
    assert numpy.all(loaded.variables[var_name].values == 2)


@pytest.mark.parametrize('arg', ['local_fs', 's3_fs'])
def test_view_overlap(
dask_client, # pylint: disable=redefined-outer-name,unused-argument
Expand Down

0 comments on commit 5d7e9a7

Please sign in to comment.