From c443fd149d28572ed4e3334b84be8e86035dfd13 Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 21:08:12 -0600 Subject: [PATCH 1/9] First cut at automated field solver This is a pretty straightforward Breadth-First Search (BFS) through a registry of calculations. The path accounts for what's available as we traverse through the "graph". --- src/metpy/calc/field_solver.py | 130 +++++++++++++++++++++++++++++++++ tests/calc/test_solver.py | 81 ++++++++++++++++++++ 2 files changed, 211 insertions(+) create mode 100644 src/metpy/calc/field_solver.py create mode 100644 tests/calc/test_solver.py diff --git a/src/metpy/calc/field_solver.py b/src/metpy/calc/field_solver.py new file mode 100644 index 00000000000..bebf04042a2 --- /dev/null +++ b/src/metpy/calc/field_solver.py @@ -0,0 +1,130 @@ +# Copyright (c) 2021 MetPy Developers. +# Distributed under the terms of the BSD 3-Clause License. +# SPDX-License-Identifier: BSD-3-Clause +"""Solver to automatically calculate derived parameters from a dataset.""" + +from collections import deque +import contextlib +from inspect import signature, Parameter + + +class Path: + def __init__(self, steps, have, need): + self.steps = steps + self.have = have + self.need = need + + @property + def need(self): + return self._need + + @need.setter + def need(self, val): + self._need = {i for i in val if i not in self.have} + + def is_complete(self): + return not bool(self.need) + + def __add__(self, other): + if any(f in set(self.steps) for f in other.steps): + raise ValueError(f'{other.steps} already in steps') + + # Prepend steps so that final path is in proper call order + return Path(other.steps + self.steps, self.have | other.have, self.need | other.need) + + def __str__(self): + return (f'Path') + + __repr__ = __str__ + + +class Solver: + names = {'Tw': 'wet_bulb_temperature', 'Td': 'dewpoint_temperature', + 'dewpoint': 'dewpoint_temperature', 'Tv': 'virtual_temperature', + 'q': 'specific_humidity', 'r': 'mixing_ratio', 'RH': 'relative_humidity', + 'p': 'pressure', 'T': 'temperature'} + + standard_names = {'temperature': 'air_temperature'} + + fallback_names = {'temperature': ['temp'], 'pressure': ['P']} + + def __init__(self): + self._graph = {} + self._funcs = {} + + def register(self, *args, inputs=None): + def dec(func): + nonlocal inputs + nonlocal args + if inputs is None: + funcsig = signature(func) + inputs = [name for name, param in funcsig.parameters.items() if + param.default is Parameter.empty] + + if not args: + args = (func.__name__,) + + normed_returns = self.normalize_names(args) + normed_inputs = self.normalize_names(inputs) + path = Path([func], set(normed_returns), set(normed_inputs)) + self._funcs[func] = (normed_inputs, normed_returns) + for ret in normed_returns: + self._graph.setdefault(ret, []).append(path) + return func + + return dec + + def normalize_names(self, names): + return [self.normalize(name) for name in names] + + def normalize(self, name): + return self.names.get(name, name) + + def _map_func_args(self, func, data): + for item in self._funcs[func][0]: + if item in data: + yield data[item] + elif item in self.standard_names: + ds = data.filter_by_attrs(standard_name=self.standard_names[item]) + yield next(iter(ds)) + else: + for name in self.fallback_names.get(item, []): + if name in data: + yield data[name] + + def calculate(self, data, name): + data = data.copy() + for func in self.solve(set(data), name): + result = func(*self._map_func_args(func, data)) + retname = self._funcs[func][-1] + if isinstance(result, tuple): + for name, val in zip(retname, result): + data[name] = val + else: + data[retname] = result + + return data[self.normalize(name)] + + def solve(self, have, want): + # Using deque as a FIFO queue by pushing at one end and popping + # from the other--this makes this a Breadth-First Search + options = deque([Path([], set(self.normalize_names(have)), {self.normalize(want)})]) + while options: + path = options.popleft() + # If calculation path is complete, return the steps + if path.is_complete(): + return path.steps + else: + # Otherwise grab one of the remaining needs and + # add all methods for calculating to the current steps + # and make them options to consider + item = path.need.pop() + for trial_step in self._graph.get(item, ()): + # with contextlib.suppress(ValueError): + options.append(path + trial_step) + + raise ValueError(f'Unable to calculate {want} from {have}') + + +solver = Solver() diff --git a/tests/calc/test_solver.py b/tests/calc/test_solver.py new file mode 100644 index 00000000000..c3b4e63382b --- /dev/null +++ b/tests/calc/test_solver.py @@ -0,0 +1,81 @@ +# Copyright (c) 2021 MetPy Developers. +# Distributed under the terms of the BSD 3-Clause License. +# SPDX-License-Identifier: BSD-3-Clause +"""Test the solver.""" +import pytest + +from metpy.calc.field_solver import Solver + +test_solver = Solver() + +@test_solver.register('Td') +def dewpoint_from_relative_humidity(temperature, relative_humidity): + pass + + +@test_solver.register('RH') +def relative_humidity_from_specific_humidity(pressure, temperature, specific_humidity): + pass + + +@test_solver.register('r') +def mixing_ratio_from_relative_humidity(pressure, temperature, relative_humidity): + pass + + +@test_solver.register('Tv') +def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=5): + pass + + +@test_solver.register('RH') +def relative_humidity_from_mixing_ratio(pressure, temperature, mixing_ratio): + pass + + +@test_solver.register('r') +def mixing_ratio_from_specific_humidity(specific_humidity): + pass + + +@test_solver.register('q') +def specific_humidity_from_mixing_ratio(mixing_ratio): + pass + + +@test_solver.register('rho') +def density(pressure, temperature, mixing_ratio): + pass + + +@test_solver.register('Tw') +def wet_bulb_temperature(pressure, temperature, dewpoint): + pass + + +@test_solver.register('u', 'v') +def wind_components(wind_speed, wind_direction): + pass + +@test_solver.register() +def vorticity(u, v): + pass + +@pytest.mark.parametrize(['inputs', 'want', 'truth'], [ + ({'T', 'RH'}, 'Td', [dewpoint_from_relative_humidity]), + ({'T', 'p', 'q'}, 'Td', [relative_humidity_from_specific_humidity, + dewpoint_from_relative_humidity]), + ({'T', 'p', 'q'}, 'Tv', [mixing_ratio_from_specific_humidity, virtual_temperature]), + ({'p', 'T', 'RH'}, 'rho', [mixing_ratio_from_relative_humidity, density]), + ({'p', 'T', 'RH'}, 'Tw', [dewpoint_from_relative_humidity, wet_bulb_temperature]), + ({'wind_speed', 'wind_direction'}, 'vorticity', [wind_components, vorticity]) +]) +def test_solutions(inputs, want, truth): + """Test that the proper sequence of calculations is found.""" + assert test_solver.solve(inputs, want) == truth + + +def test_failure(): + """Test that the correct error results when a value cannot be solved.""" + with pytest.raises(ValueError): + test_solver.solve({'RH'}, 'Td') From 9cf537f1263b3360bad0449b34d00e04fe4b5027 Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 21:30:46 -0600 Subject: [PATCH 2/9] Register some calculations with the solver Add a basic test for the calculate functionality. --- src/metpy/calc/__init__.py | 4 +++- src/metpy/calc/basic.py | 6 ++++++ src/metpy/calc/thermo.py | 17 +++++++++++++++++ tests/calc/test_solver.py | 14 ++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/metpy/calc/__init__.py b/src/metpy/calc/__init__.py index 736a35f0913..68eb6a0e92d 100644 --- a/src/metpy/calc/__init__.py +++ b/src/metpy/calc/__init__.py @@ -6,6 +6,7 @@ from .basic import * # noqa: F403 from .cross_sections import * # noqa: F403 from .exceptions import * # noqa: F403 +from .field_solver import solver from .indices import * # noqa: F403 from .kinematics import * # noqa: F403 from .thermo import * # noqa: F403 @@ -13,7 +14,8 @@ from .turbulence import * # noqa: F403 from ..package_tools import set_module -__all__ = basic.__all__[:] # pylint: disable=undefined-variable +__all__ = ['solver'] +__all__.extend(basic.__all__) # pylint: disable=undefined-variable __all__.extend(cross_sections.__all__) # pylint: disable=undefined-variable __all__.extend(indices.__all__) # pylint: disable=undefined-variable __all__.extend(kinematics.__all__) # pylint: disable=undefined-variable diff --git a/src/metpy/calc/basic.py b/src/metpy/calc/basic.py index 7c53d25be86..773581d16c1 100644 --- a/src/metpy/calc/basic.py +++ b/src/metpy/calc/basic.py @@ -16,6 +16,7 @@ import numpy as np from scipy.ndimage import gaussian_filter +from . import solver from .. import constants as mpconsts from ..package_tools import Exporter from ..units import check_units, masked_array, units @@ -30,6 +31,7 @@ @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='u') @check_units('[speed]', '[speed]') def wind_speed(u, v): @@ -57,6 +59,7 @@ def wind_speed(u, v): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='u') @check_units('[speed]', '[speed]') def wind_direction(u, v, convention='from'): @@ -112,6 +115,7 @@ def wind_direction(u, v, convention='from'): @exporter.export +@solver.register('u', 'v') @preprocess_and_wrap(wrap_like=('speed', 'speed')) @check_units('[speed]') def wind_components(speed, wind_direction): @@ -153,6 +157,7 @@ def wind_components(speed, wind_direction): @exporter.export +@solver.register('temperature', 'wind_speed') @preprocess_and_wrap(wrap_like='temperature') @check_units(temperature='[temperature]', speed='[speed]') def windchill(temperature, speed, face_level_winds=False, mask_undefined=True): @@ -215,6 +220,7 @@ def windchill(temperature, speed, face_level_winds=False, mask_undefined=True): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='temperature') @check_units('[temperature]') def heat_index(temperature, relative_humidity, mask_undefined=True): diff --git a/src/metpy/calc/thermo.py b/src/metpy/calc/thermo.py index 35053e8ef68..460e49d327c 100644 --- a/src/metpy/calc/thermo.py +++ b/src/metpy/calc/thermo.py @@ -10,6 +10,7 @@ import scipy.optimize as so import xarray as xr +from . import solver from .exceptions import InvalidSoundingError from .tools import (_greater_or_close, _less_or_close, _remove_nans, find_bounding_indices, find_intersections, first_derivative, get_layer) @@ -26,6 +27,7 @@ @exporter.export +@solver.register('relative_humidity') @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'dewpoint')) @check_units('[temperature]', '[temperature]') def relative_humidity_from_dewpoint(temperature, dewpoint): @@ -101,6 +103,7 @@ def exner_function(pressure, reference_pressure=mpconsts.P0): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='temperature', broadcast=('pressure', 'temperature')) @check_units('[pressure]', '[temperature]') def potential_temperature(pressure, temperature): @@ -148,6 +151,7 @@ def potential_temperature(pressure, temperature): broadcast=('pressure', 'potential_temperature') ) @check_units('[pressure]', '[temperature]') +@solver.register('temperature') def temperature_from_potential_temperature(pressure, potential_temperature): r"""Calculate the temperature from a given potential temperature. @@ -1011,6 +1015,7 @@ def saturation_vapor_pressure(temperature): @exporter.export +@solver.register('dewpoint') @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'relative_humidity')) @check_units('[temperature]', '[dimensionless]') def dewpoint_from_relative_humidity(temperature, relative_humidity): @@ -1163,6 +1168,7 @@ def saturation_mixing_ratio(total_press, temperature): @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'dewpoint') @@ -1290,6 +1296,7 @@ def saturation_equivalent_potential_temperature(pressure, temperature): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'mixing_ratio')) @check_units('[temperature]', '[dimensionless]', '[dimensionless]') def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=mpconsts.epsilon): @@ -1329,6 +1336,7 @@ def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=mpcons @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'mixing_ratio') @@ -1375,6 +1383,7 @@ def virtual_potential_temperature(pressure, temperature, mixing_ratio, @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'mixing_ratio') @@ -1533,6 +1542,7 @@ def psychrometric_vapor_pressure_wet(pressure, dry_bulb_temperature, wet_bulb_te @exporter.export +@solver.register('mixing_ratio') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'relative_humidity') @@ -1581,6 +1591,7 @@ def mixing_ratio_from_relative_humidity(pressure, temperature, relative_humidity @exporter.export +@solver.register('relative_humidity') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'mixing_ratio') @@ -1627,6 +1638,7 @@ def relative_humidity_from_mixing_ratio(pressure, temperature, mixing_ratio): @exporter.export +@solver.register('mixing_ratio') @preprocess_and_wrap(wrap_like='specific_humidity') @check_units('[dimensionless]') def mixing_ratio_from_specific_humidity(specific_humidity): @@ -1662,6 +1674,7 @@ def mixing_ratio_from_specific_humidity(specific_humidity): @exporter.export +@solver.register('specific_humidity') @preprocess_and_wrap(wrap_like='mixing_ratio') @check_units('[dimensionless]') def specific_humidity_from_mixing_ratio(mixing_ratio): @@ -1697,6 +1710,7 @@ def specific_humidity_from_mixing_ratio(mixing_ratio): @exporter.export +@solver.register('relative_humidity') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'specific_humidity') @@ -2946,6 +2960,7 @@ def brunt_vaisala_period(height, potential_temperature, vertical_dim=0): @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'dewpoint') @@ -3051,6 +3066,7 @@ def static_stability(pressure, temperature, vertical_dim=0): @exporter.export +@solver.register('dewpoint') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'specific_humdiity') @@ -3188,6 +3204,7 @@ def vertical_velocity(omega, pressure, temperature, mixing_ratio=0): @exporter.export +@solver.register('specific_humidity') @preprocess_and_wrap(wrap_like='dewpoint', broadcast=('dewpoint', 'pressure')) @check_units('[pressure]', '[temperature]') def specific_humidity_from_dewpoint(pressure, dewpoint): diff --git a/tests/calc/test_solver.py b/tests/calc/test_solver.py index c3b4e63382b..bf6d80a35e8 100644 --- a/tests/calc/test_solver.py +++ b/tests/calc/test_solver.py @@ -3,8 +3,12 @@ # SPDX-License-Identifier: BSD-3-Clause """Test the solver.""" import pytest +import xarray as xr +from metpy.calc import solver from metpy.calc.field_solver import Solver +from metpy.testing import assert_almost_equal +from metpy.units import units test_solver = Solver() @@ -79,3 +83,13 @@ def test_failure(): """Test that the correct error results when a value cannot be solved.""" with pytest.raises(ValueError): test_solver.solve({'RH'}, 'Td') + +def test_calculate(): + """Test using the solver results to calculate.""" + temp = xr.DataArray(25, attrs={'units': 'degC'}) + rh = xr.DataArray(80, attrs={'units': 'percent'}) + press = xr.DataArray(994, attrs={'units': 'hPa'}) + data = xr.Dataset({'temperature': temp, 'relative_humidity': rh, 'pressure': press}) + + dewp = solver.calculate(data, 'dewpoint') + assert_almost_equal(dewp, units.Quantity(21.3125, 'degC'), 4) From 7a233f19321b81cd5f203153e98d07df6ffc61ac Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 21:36:47 -0600 Subject: [PATCH 3/9] Hook up the solver inside of declarative --- src/metpy/plots/declarative.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/metpy/plots/declarative.py b/src/metpy/plots/declarative.py index ac7cc01cdb2..8da12279999 100644 --- a/src/metpy/plots/declarative.py +++ b/src/metpy/plots/declarative.py @@ -20,7 +20,7 @@ from ._mpl import TextCollection from .cartopy_utils import import_cartopy from .station_plot import StationPlot -from ..calc import reduce_point_density +from ..calc import reduce_point_density, solver from ..package_tools import Exporter from ..units import units @@ -1083,7 +1083,10 @@ def griddata(self): # Select our particular field of interest if self.field: - data = self.data.metpy.parse_cf(self.field) + if self.field in self.data: + data = self.data.metpy.parse_cf(self.field) + else: + data = solver.calculate(self.data, self.field) elif hasattr(self.data.metpy, 'parse_cf'): # Handles the case where we have a dataset but no specified field raise ValueError('field attribute has not been set.') From e32e6bc9692f2a35fb9480069bee6612a317a246 Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 21:44:20 -0600 Subject: [PATCH 4/9] Fix circular imports --- src/metpy/calc/basic.py | 2 +- src/metpy/calc/thermo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/metpy/calc/basic.py b/src/metpy/calc/basic.py index 773581d16c1..8afce8130df 100644 --- a/src/metpy/calc/basic.py +++ b/src/metpy/calc/basic.py @@ -16,7 +16,7 @@ import numpy as np from scipy.ndimage import gaussian_filter -from . import solver +from .field_solver import solver from .. import constants as mpconsts from ..package_tools import Exporter from ..units import check_units, masked_array, units diff --git a/src/metpy/calc/thermo.py b/src/metpy/calc/thermo.py index 460e49d327c..8d658008d94 100644 --- a/src/metpy/calc/thermo.py +++ b/src/metpy/calc/thermo.py @@ -10,7 +10,7 @@ import scipy.optimize as so import xarray as xr -from . import solver +from .field_solver import solver from .exceptions import InvalidSoundingError from .tools import (_greater_or_close, _less_or_close, _remove_nans, find_bounding_indices, find_intersections, first_derivative, get_layer) From 2f33f8de028e89de9f6975aad2d6340ac6d1c390 Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 22:46:26 -0600 Subject: [PATCH 5/9] fixup! Hook up the solver inside of declarative --- src/metpy/plots/declarative.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/metpy/plots/declarative.py b/src/metpy/plots/declarative.py index 8da12279999..c2565068d71 100644 --- a/src/metpy/plots/declarative.py +++ b/src/metpy/plots/declarative.py @@ -1086,7 +1086,7 @@ def griddata(self): if self.field in self.data: data = self.data.metpy.parse_cf(self.field) else: - data = solver.calculate(self.data, self.field) + data = solver.calculate(self.data.metpy.parse_cf(), self.field) elif hasattr(self.data.metpy, 'parse_cf'): # Handles the case where we have a dataset but no specified field raise ValueError('field attribute has not been set.') From 42da65ea5612b99f2b29441c5f814bed6de03edb Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 22:54:23 -0600 Subject: [PATCH 6/9] BUG: Fix heat_index with misaligned arrays Need to broadcast temperature and relative_humidity together. --- src/metpy/calc/basic.py | 2 +- tests/calc/test_basic.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/metpy/calc/basic.py b/src/metpy/calc/basic.py index 8afce8130df..cba3c0fc868 100644 --- a/src/metpy/calc/basic.py +++ b/src/metpy/calc/basic.py @@ -221,7 +221,7 @@ def windchill(temperature, speed, face_level_winds=False, mask_undefined=True): @exporter.export @solver.register() -@preprocess_and_wrap(wrap_like='temperature') +@preprocess_and_wrap(broadcast=('temperature', 'relative_humidity'), wrap_like='temperature') @check_units('[temperature]') def heat_index(temperature, relative_humidity, mask_undefined=True): r"""Calculate the Heat Index from the current temperature and relative humidity. diff --git a/tests/calc/test_basic.py b/tests/calc/test_basic.py index e6ae8878e9b..3a75aa54b28 100644 --- a/tests/calc/test_basic.py +++ b/tests/calc/test_basic.py @@ -281,6 +281,16 @@ def test_heat_index_kelvin(): assert_almost_equal(hi.to('degC'), 50.3406 * units.degC, 4) +def test_heat_index_xarray(): + """Test heat_index when working with fields from xarray.""" + temp = xr.DataArray(np.full((1, 4, 2, 3), 35.), attrs={'units': 'degC'}, + dims=('t', 'p', 'y', 'x')) + rh = xr.DataArray(np.full((4, 1, 2, 3), 0.7), dims = ('p', 't', 'y', 'x')) + + hi = heat_index(temp, rh) + assert_almost_equal(hi, units.Quantity(50.3405, 'degC'), 4) + + def test_height_to_geopotential(array_type): """Test conversion from height to geopotential.""" mask = [False, True, False, True] From 462dfa913c2d07643539d955e81c592ee705ca3d Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 22:55:40 -0600 Subject: [PATCH 7/9] Fix solver traversal Need to not include calculated parameters in what we "have" because it's not actually available higher up in the call stack--instead manually remove from the "need" list. This also means we now need the code that detects and breaks call cycles. --- src/metpy/calc/field_solver.py | 10 +++++++--- tests/calc/test_solver.py | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/metpy/calc/field_solver.py b/src/metpy/calc/field_solver.py index bebf04042a2..1439c17ac23 100644 --- a/src/metpy/calc/field_solver.py +++ b/src/metpy/calc/field_solver.py @@ -30,7 +30,10 @@ def __add__(self, other): raise ValueError(f'{other.steps} already in steps') # Prepend steps so that final path is in proper call order - return Path(other.steps + self.steps, self.have | other.have, self.need | other.need) + # Don't really "have" what's in the new function call, but instead it just needs + # to be removed from what's needed. + return Path(other.steps + self.steps, self.have, + (self.need | other.need) - other.have) def __str__(self): return (f'Path Date: Mon, 30 Aug 2021 22:57:06 -0600 Subject: [PATCH 8/9] Improve name mapping This allows us to automatically calculate heat_index from NARR output. --- src/metpy/calc/field_solver.py | 25 +++++++++++++------------ tests/calc/test_solver.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/metpy/calc/field_solver.py b/src/metpy/calc/field_solver.py index 1439c17ac23..215d498faea 100644 --- a/src/metpy/calc/field_solver.py +++ b/src/metpy/calc/field_solver.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: BSD-3-Clause """Solver to automatically calculate derived parameters from a dataset.""" -from collections import deque +from collections import ChainMap, deque import contextlib from inspect import signature, Parameter @@ -43,14 +43,14 @@ def __str__(self): class Solver: - names = {'Tw': 'wet_bulb_temperature', 'Td': 'dewpoint_temperature', - 'dewpoint': 'dewpoint_temperature', 'Tv': 'virtual_temperature', - 'q': 'specific_humidity', 'r': 'mixing_ratio', 'RH': 'relative_humidity', - 'p': 'pressure', 'T': 'temperature'} + names = {'tw': 'wet_bulb_temperature', 'td': 'dewpoint_temperature', + 'dewpoint': 'dewpoint_temperature', 'tv': 'virtual_temperature', + 'q': 'specific_humidity', 'r': 'mixing_ratio', 'rh': 'relative_humidity', + 'p': 'pressure', 't': 'temperature', 'isobaric': 'pressure'} standard_names = {'temperature': 'air_temperature'} - fallback_names = {'temperature': ['temp'], 'pressure': ['P']} + fallback_names = {'temperature': ['temp'], 'pressure': ['P', 'isobaric']} def __init__(self): self._graph = {} @@ -82,23 +82,24 @@ def normalize_names(self, names): return [self.normalize(name) for name in names] def normalize(self, name): - return self.names.get(name, name) + return self.names.get(name.lower(), name.lower()) def _map_func_args(self, func, data): + key_map = {self.normalize(key): key for key in ChainMap(data, data.coords)} for item in self._funcs[func][0]: - if item in data: - yield data[item] + if item in key_map: + yield data[key_map[item]] elif item in self.standard_names: ds = data.filter_by_attrs(standard_name=self.standard_names[item]) yield next(iter(ds)) else: for name in self.fallback_names.get(item, []): - if name in data: - yield data[name] + if name in key_map: + yield data[key_map[name]] def calculate(self, data, name): data = data.copy() - for func in self.solve(set(data), name): + for func in self.solve(set(data) | set(data.coords), name): result = func(*self._map_func_args(func, data)) retname = self._funcs[func][-1] if isinstance(result, tuple): diff --git a/tests/calc/test_solver.py b/tests/calc/test_solver.py index ff5adb8e275..f0b88bccea6 100644 --- a/tests/calc/test_solver.py +++ b/tests/calc/test_solver.py @@ -106,3 +106,13 @@ def test_calculate(): dewp = solver.calculate(data, 'dewpoint') assert_almost_equal(dewp, units.Quantity(21.3125, 'degC'), 4) + + +def test_xarray_mapping(): + temp = xr.DataArray([25], attrs={'units': 'degC'}) + q = xr.DataArray([.019], attrs={'units': 'percent'}) + press = xr.DataArray([994], attrs={'units': 'hPa'}) + data = xr.Dataset({'Temperature': temp, 'Specific_humidity': q}, + coords={'isobaric': press}) + + solver.calculate(data, 'heat_index') From fe809ebe86799065447e6f78a19a73946c1bfdff Mon Sep 17 00:00:00 2001 From: Ryan May Date: Mon, 30 Aug 2021 23:01:31 -0600 Subject: [PATCH 9/9] Simplify management of what's "needed" --- src/metpy/calc/field_solver.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/metpy/calc/field_solver.py b/src/metpy/calc/field_solver.py index 215d498faea..22d0d34de56 100644 --- a/src/metpy/calc/field_solver.py +++ b/src/metpy/calc/field_solver.py @@ -14,14 +14,6 @@ def __init__(self, steps, have, need): self.have = have self.need = need - @property - def need(self): - return self._need - - @need.setter - def need(self, val): - self._need = {i for i in val if i not in self.have} - def is_complete(self): return not bool(self.need) @@ -33,7 +25,7 @@ def __add__(self, other): # Don't really "have" what's in the new function call, but instead it just needs # to be removed from what's needed. return Path(other.steps + self.steps, self.have, - (self.need | other.need) - other.have) + (self.need | other.need) - (self.have | other.have)) def __str__(self): return (f'Path