diff --git a/Changelog.rst b/Changelog.rst index 506baebb63..dc36568b30 100644 --- a/Changelog.rst +++ b/Changelog.rst @@ -3,11 +3,16 @@ version NEXTVERSION **2024-??-??** +* New method: `cf.Field.is_discrete_axis` + (https://github.com/NCAS-CMS/cf-python/issues/784) * Include the UM version as a field property when reading UM files (https://github.com/NCAS-CMS/cf-python/issues/777) -* Fix bug where `cf.example_fields` returned a `list` - of Fields rather than a `Fieldlist` +* Fix bug where `cf.example_fields` returned a `list` of Fields rather + than a `Fieldlist` (https://github.com/NCAS-CMS/cf-python/issues/725) +* Fix bug where combining UGRID fields erroneously creates an extra + axis and broadcasts over it + (https://github.com/NCAS-CMS/cf-python/issues/784) * Fix bug where `cf.normalize_slice` doesn't correctly handle certain cyclic slices (https://github.com/NCAS-CMS/cf-python/issues/774) diff --git a/cf/cellmethod.py b/cf/cellmethod.py index 67ebf72cbd..3bdfbf3659 100644 --- a/cf/cellmethod.py +++ b/cf/cellmethod.py @@ -486,8 +486,7 @@ def expand_intervals(self, inplace=False, i=False): @_deprecated_kwarg_check("i", version="3.0.0", removed_at="4.0.0") @_inplace_enabled(default=False) def change_axes(self, axis_map, inplace=False, i=False): - """Change the axes of the cell method according to a given - mapping. + """Replace the axes of the cell method. :Parameters: diff --git a/cf/field.py b/cf/field.py index 61132fb6b3..60601dc632 100644 --- a/cf/field.py +++ b/cf/field.py @@ -1,5 +1,5 @@ import logging -from collections import namedtuple +from dataclasses import dataclass from functools import reduce from operator import mul as operator_mul from os import sep @@ -190,11 +190,29 @@ "__ge__", ) -_xxx = namedtuple( - "data_dimension", ["size", "axis", "key", "coord", "coord_type", "scalar"] -) -# _empty_set = set() +@dataclass() +class _Axis_characterisation: + """Characterise a domain axis. + + Used by `_binary_operation` to help with ascertaining if there is + a common axis in two fields. + + .. versionaddedd:: NEXTVERSION + + """ + + # The size of the axis, an integer. + size: int = -1 + # The domain axis identifier. E.g. 'domainaxis0' + axis: str = "" + # The coordinate constructs that characterize the axis + coords: tuple = () + # The identifiers of the coordinate + # constructs. E.g. ('dimensioncoordinate1',) + keys: tuple = () + # Whether or not the axis is spanned by the field's data array + axis_in_data_axes: bool = True class Field(mixin.FieldDomain, mixin.PropertiesData, cfdm.Field): @@ -985,80 +1003,127 @@ def _binary_operation(self, other, method): data_axes = f.get_data_axes() for axis in f.domain_axes(todict=True): identity = None - key = None - coord = None - coord_type = None - key, coord = f.dimension_coordinate( - item=True, default=(None, None), filter_by_axis=(axis,) - ) - if coord is not None: - # This axis of the domain has a dimension - # coordinate - identity = coord.identity(strict=True, default=None) - if identity is None: - # Dimension coordinate has no identity, but it - # may have a recognised axis. - for ctype in ("T", "X", "Y", "Z"): - if getattr(coord, ctype, False): - identity = ctype - break - - if identity is None and relaxed_identities: - identity = coord.identity(relaxed=True, default=None) - else: - key, coord = f.auxiliary_coordinate( - item=True, - default=(None, None), + if self.is_discrete_axis(axis): + # This is a discrete axis whose identity is + # inferred from all of its auxiliary coordinates + x = {} + for key, aux_coord in f.auxiliary_coordinates( filter_by_axis=(axis,), - axis_mode="exact", + axis_mode="and", + todict=True, + ).items(): + identity = aux_coord.identity( + strict=True, default=None + ) + if identity is None and relaxed_identities: + identity = aux_coord.identity( + relaxed=True, default=None + ) + if identity is not None: + x[identity] = key + + if x: + # Get the sorted identities (sorted so that + # they're comparable between fields) and their + # corresponding keys. + # + # E.g. {2:3, 4:6, 1:7} -> (1, 2, 4), (7, 3, 6) + identity, keys = tuple(zip(*sorted(x.items()))) + coords = tuple( + f.auxiliary_coordinate(key) for key in keys + ) + else: + identity = None + keys = () + coords = () + else: + key, dim_coord = f.dimension_coordinate( + item=True, default=(None, None), filter_by_axis=(axis,) ) - if coord is not None: - # This axis of the domain does not have a - # dimension coordinate but it does have - # exactly one 1-d auxiliary coordinate, so - # that will do. - identity = coord.identity(strict=True, default=None) + if dim_coord is not None: + # This non-discrete axis has a dimension + # coordinate + identity = dim_coord.identity( + strict=True, default=None + ) + if identity is None: + # Dimension coordinate has no identity, + # but it may have a recognised axis. + for ctype in ("T", "X", "Y", "Z"): + if getattr(dim_coord, ctype, False): + identity = ctype + break if identity is None and relaxed_identities: - identity = coord.identity( + identity = dim_coord.identity( relaxed=True, default=None ) + keys = (key,) + coords = (dim_coord,) + else: + key, aux_coord = f.auxiliary_coordinate( + item=True, + default=(None, None), + filter_by_axis=(axis,), + axis_mode="exact", + ) + if aux_coord is not None: + # This non-discrete axis does not have a + # dimension coordinate but it does have + # exactly one 1-d auxiliary coordinate, so + # that will do. + coords = (aux_coord,) + identity = aux_coord.identity( + strict=True, default=None + ) + if identity is None and relaxed_identities: + identity = aux_coord.identity( + relaxed=True, default=None + ) + if identity is None: identity = i - else: - coord_type = coord.construct_type - out[identity] = _xxx( + out[identity] = _Axis_characterisation( size=f.domain_axis(axis).get_size(), axis=axis, - key=key, - coord=coord, - coord_type=coord_type, - scalar=(axis not in data_axes), + keys=keys, + coords=coords, + axis_in_data_axes=axis in data_axes, ) for identity, y in tuple(out1.items()): - asdas = True - if y.scalar and identity in out0 and isinstance(identity, str): + missing_axis_inserted = False + if ( + not y.axis_in_data_axes + and identity in out0 + and isinstance(identity, str) + ): a = out0[identity] if a.size > 1: + # Put missing axis into data axes field1.insert_dimension(y.axis, position=0, inplace=True) - asdas = False + missing_axis_inserted = True - if y.scalar and asdas: + if not missing_axis_inserted and not y.axis_in_data_axes: del out1[identity] for identity, a in tuple(out0.items()): - asdas = True - if a.scalar and identity in out1 and isinstance(identity, str): + missing_axis_inserted = False + if ( + not a.axis_in_data_axes + and identity in out1 + and isinstance(identity, str) + ): y = out1[identity] if y.size > 1: + # Put missing axis into data axes field0.insert_dimension(a.axis, position=0, inplace=True) - asdas = False + missing_axis_inserted = True - if a.scalar and asdas: + if not missing_axis_inserted and not a.axis_in_data_axes: del out0[identity] squeeze1 = [] @@ -1069,15 +1134,14 @@ def _binary_operation(self, other, method): axes_added_from_field1 = [] # Dictionary of size > 1 axes from field1 which will replace - # matching size 1 axes in field0. E.g. {'domainaxis1': - # data_dimension( - # size=8, - # axis='domainaxis1', - # key='dimensioncoordinate1', - # coord=, - # coord_type='dimension_coordinate', - # scalar=False - # ) + # matching size 1 axes in field0. + # + # E.g. {'domainaxis1': _Axis_characterisation( + # size=8, + # axis='domainaxis1', + # keys=('dimensioncoordinate1',), + # coords=(CF DimensionCoordinate: longitude(8) degrees_east>,), + # axis_in_data_axes=True) # } axes_to_replace_from_field1 = {} @@ -1178,48 +1242,55 @@ def _binary_operation(self, other, method): f"{a.size} {identity!r} axis" ) - # Ensure matching axis directions - if y.coord.direction() != a.coord.direction(): - other.flip(y.axis, inplace=True) + for a_key, a_coord, y_key, y_coord in zip( + a.keys, a.coords, y.keys, y.coords + ): + # Ensure matching axis directions + if y_coord.direction() != a_coord.direction(): + other.flip(y.axis, inplace=True) - # Check for matching coordinate values - if not y.coord._equivalent_data(a.coord): - raise ValueError( - f"Can't combine size {y.size} {identity!r} axes with " - f"non-matching coordinate values" - ) + # Check for matching coordinate values + if not y_coord._equivalent_data(a_coord): + raise ValueError( + f"Can't combine size {y.size} {identity!r} axes with " + f"non-matching coordinate values" + ) - # Check coord refs - refs1 = field1.get_coordinate_reference(construct=y.key, key=True) - refs0 = field0.get_coordinate_reference(construct=a.key, key=True) + # Check coord refs + refs1 = field1.get_coordinate_reference( + construct=y_key, key=True + ) + refs0 = field0.get_coordinate_reference( + construct=a_key, key=True + ) - n_refs = len(refs1) - n0_refs = len(refs0) + n_refs = len(refs1) + n0_refs = len(refs0) - if n_refs != n0_refs: - raise ValueError( - f"Can't combine {self.__class__.__name__!r} with " - f"{other.__class__.__name__!r} because the coordinate " - f"references have different lengths: {n_refs} and " - f"{n0_refs}." - ) + if n_refs != n0_refs: + raise ValueError( + f"Can't combine {self.__class__.__name__!r} with " + f"{other.__class__.__name__!r} because the coordinate " + f"references have different lengths: {n_refs} and " + f"{n0_refs}." + ) - n_equivalent_refs = 0 - for ref1 in refs1: - for ref0 in refs0[:]: - if field1._equivalent_coordinate_references( - field0, key0=ref1, key1=ref0, axis_map=axis_map - ): - n_equivalent_refs += 1 - refs0.remove(ref0) - break + n_equivalent_refs = 0 + for ref1 in refs1: + for ref0 in refs0[:]: + if field1._equivalent_coordinate_references( + field0, key0=ref1, key1=ref0, axis_map=axis_map + ): + n_equivalent_refs += 1 + refs0.remove(ref0) + break - if n_equivalent_refs != n_refs: - raise ValueError( - f"Can't combine {self.__class__.__name__!r} with " - f"{other.__class__.__name__!r} because the fields have " - "incompatible coordinate references." - ) + if n_equivalent_refs != n_refs: + raise ValueError( + f"Can't combine {self.__class__.__name__!r} with " + f"{other.__class__.__name__!r} because the fields " + "have incompatible coordinate references." + ) # Change the domain axis sizes in field0 so that they match # the broadcasted result data diff --git a/cf/mixin/fielddomain.py b/cf/mixin/fielddomain.py index 665658bb66..c8e4fb1bb2 100644 --- a/cf/mixin/fielddomain.py +++ b/cf/mixin/fielddomain.py @@ -2347,6 +2347,98 @@ def iscyclic(self, *identity, **filter_kwargs): return axis in self.cyclic() + def is_discrete_axis(self, *identity, **filter_kwargs): + """Return True if the given axis is discrete. + + In general, a discrete axis is any axis that does not + correspond to a continuous physical quantity, but only the + following types of discrete axis are identified here: + + * The feature instance axis of a discrete sampling geometry + (DSG) domain. + + * An axis spanned by the domain topology construct of an + unstructured grid. + + * The axis with geometry cells. + + .. versionaddedd:: NEXTVERSION + + .. seealso:: `domain_axis`, `coordinates` + + :Parameters: + + identity: `tuple`, optional + Select domain axis constructs that have an identity, + defined by their `!identities` methods, that matches + any of the given values. + + Additionally, the values are matched against construct + identifiers, with or without the ``'key%'`` prefix. + + Additionally, if for a given ``value``, + ``f.coordinates(value, filter_by_naxes=(1,))`` returns + 1-d coordinate constructs that all span the same + domain axis construct then that domain axis construct + is selected. See `coordinates` for details. + + Additionally, if there is a `Field` data array and a + value matches the integer position of an array + dimension, then the corresponding domain axis + construct is selected. + + If no values are provided then all domain axis + constructs are selected. + + {{value match}} + + {{displayed identity}} + + {{filter_kwargs: optional}} + + **Examples** + + >>> f = cf.example_{{class_lower}}(8) + >>> f.is_discrete_axis('X') + True + >>> f.is_discrete_axis('T') + False + + """ + # Get the axis key + axis = self.domain_axis(*identity, key=True, **filter_kwargs) + + # DSG + if self.has_property("featureType") and not self.dimension_coordinate( + filter_by_axis=(axis,), default=False + ): + ctypes = ("X", "Y", "T") + n = 0 + for aux in self.auxiliary_coordinates( + filter_by_axis=(axis,), axis_mode="and", todict=True + ).values(): + if aux.ctype in ctypes: + n += 1 + + if n == len(ctypes): + return True + + # UGRID + if self.domain_topologies( + filter_by_axis=(axis,), axis_mode="exact", todict=True + ): + return True + + # Geometries + for aux in self.auxiliary_coordinates( + filter_by_axis=(axis,), axis_mode="exact", todict=True + ).values(): + if aux.get_geometry(None): + return True + + # Still here? Then the axis is not discrete. + return False + def match_by_rank(self, *ranks): """Whether or not the number of domain axis constructs satisfies conditions. diff --git a/cf/test/test_Field.py b/cf/test/test_Field.py index 62048773e2..13366c52f2 100644 --- a/cf/test/test_Field.py +++ b/cf/test/test_Field.py @@ -749,6 +749,11 @@ def test_Field__add__(self): with self.assertRaises(TypeError): f + ("a string",) + # Addition with a UGRID discrete axis + f = cf.example_field(8) + g = f + f + self.assertEqual(g.shape, f.shape) + def test_Field__mul__(self): f = self.f.copy().squeeze() @@ -2883,6 +2888,28 @@ def test_Field_cyclic_iscyclic(self): self.assertEqual(f2.cyclic(), set(("domainaxis2",))) self.assertTrue(f2.iscyclic("X")) + def test_Field_is_discrete_axis(self): + """Test the `is_discrete_axis` Field method.""" + # No discrete axes + f = cf.example_field(1) + for axis in f.domain_axes(todict=True): + self.assertFalse(f.is_discrete_axis(axis)) + + # UGRID + f = cf.example_field(8) + self.assertTrue(f.is_discrete_axis("X")) + self.assertFalse(f.is_discrete_axis("T")) + + # DSG + f = cf.example_field(4) + self.assertTrue(f.is_discrete_axis("cf_role=timeseries_id")) + self.assertFalse(f.is_discrete_axis("ncdim%timeseries")) + + # Geometries + f = cf.example_field(6) + self.assertTrue(f.is_discrete_axis("cf_role=timeseries_id")) + self.assertFalse(f.is_discrete_axis("time")) + if __name__ == "__main__": print("Run date:", datetime.datetime.now()) diff --git a/cf/test/test_functions.py b/cf/test/test_functions.py index 370d0a9036..9cc5e5eab1 100644 --- a/cf/test/test_functions.py +++ b/cf/test/test_functions.py @@ -389,24 +389,24 @@ def test_normalize_slice(self): cf.normalize_slice(slice(2, 5, -2), 8, cyclic=True), slice(2, -3, -2), ) - + self.assertEqual( cf.normalize_slice(slice(-8, 0, 1), 8, cyclic=True), - slice(-8, 0, 1) - ) + slice(-8, 0, 1), + ) self.assertEqual( cf.normalize_slice(slice(0, 7, -1), 8, cyclic=True), - slice(0, -1, -1) - ) + slice(0, -1, -1), + ) self.assertEqual( cf.normalize_slice(slice(-1, -8, 1), 8, cyclic=True), - slice(-1, 0, 1) - ) + slice(-1, 0, 1), + ) self.assertEqual( cf.normalize_slice(slice(-8, -1, -1), 8, cyclic=True), - slice(0, -1, -1) - ) - + slice(0, -1, -1), + ) + with self.assertRaises(IndexError): cf.normalize_slice([1, 2], 8) diff --git a/docs/source/class/cf.Domain.rst b/docs/source/class/cf.Domain.rst index 3117abd118..e3d076733f 100644 --- a/docs/source/class/cf.Domain.rst +++ b/docs/source/class/cf.Domain.rst @@ -207,6 +207,7 @@ Domain axes ~cf.Domain.direction ~cf.Domain.directions ~cf.Domain.iscyclic + ~cf.Domain.is_discrete_axis Subspacing ---------- diff --git a/docs/source/class/cf.Field.rst b/docs/source/class/cf.Field.rst index c0a0275daa..8bbe0f5743 100644 --- a/docs/source/class/cf.Field.rst +++ b/docs/source/class/cf.Field.rst @@ -518,6 +518,7 @@ Domain axes ~cf.Field.direction ~cf.Field.directions ~cf.Field.iscyclic + ~cf.Field.is_discrete_axis ~cf.Field.isperiodic ~cf.Field.item_axes ~cf.Field.items_axes