diff --git a/src/metpy/calc/__init__.py b/src/metpy/calc/__init__.py index 736a35f0913..68eb6a0e92d 100644 --- a/src/metpy/calc/__init__.py +++ b/src/metpy/calc/__init__.py @@ -6,6 +6,7 @@ from .basic import * # noqa: F403 from .cross_sections import * # noqa: F403 from .exceptions import * # noqa: F403 +from .field_solver import solver from .indices import * # noqa: F403 from .kinematics import * # noqa: F403 from .thermo import * # noqa: F403 @@ -13,7 +14,8 @@ from .turbulence import * # noqa: F403 from ..package_tools import set_module -__all__ = basic.__all__[:] # pylint: disable=undefined-variable +__all__ = ['solver'] +__all__.extend(basic.__all__) # pylint: disable=undefined-variable __all__.extend(cross_sections.__all__) # pylint: disable=undefined-variable __all__.extend(indices.__all__) # pylint: disable=undefined-variable __all__.extend(kinematics.__all__) # pylint: disable=undefined-variable diff --git a/src/metpy/calc/basic.py b/src/metpy/calc/basic.py index 7c53d25be86..cba3c0fc868 100644 --- a/src/metpy/calc/basic.py +++ b/src/metpy/calc/basic.py @@ -16,6 +16,7 @@ import numpy as np from scipy.ndimage import gaussian_filter +from .field_solver import solver from .. import constants as mpconsts from ..package_tools import Exporter from ..units import check_units, masked_array, units @@ -30,6 +31,7 @@ @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='u') @check_units('[speed]', '[speed]') def wind_speed(u, v): @@ -57,6 +59,7 @@ def wind_speed(u, v): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='u') @check_units('[speed]', '[speed]') def wind_direction(u, v, convention='from'): @@ -112,6 +115,7 @@ def wind_direction(u, v, convention='from'): @exporter.export +@solver.register('u', 'v') @preprocess_and_wrap(wrap_like=('speed', 'speed')) @check_units('[speed]') def wind_components(speed, wind_direction): @@ -153,6 +157,7 @@ def wind_components(speed, wind_direction): @exporter.export +@solver.register('temperature', 'wind_speed') @preprocess_and_wrap(wrap_like='temperature') @check_units(temperature='[temperature]', speed='[speed]') def windchill(temperature, speed, face_level_winds=False, mask_undefined=True): @@ -215,7 +220,8 @@ def windchill(temperature, speed, face_level_winds=False, mask_undefined=True): @exporter.export -@preprocess_and_wrap(wrap_like='temperature') +@solver.register() +@preprocess_and_wrap(broadcast=('temperature', 'relative_humidity'), wrap_like='temperature') @check_units('[temperature]') def heat_index(temperature, relative_humidity, mask_undefined=True): r"""Calculate the Heat Index from the current temperature and relative humidity. diff --git a/src/metpy/calc/field_solver.py b/src/metpy/calc/field_solver.py new file mode 100644 index 00000000000..22d0d34de56 --- /dev/null +++ b/src/metpy/calc/field_solver.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 MetPy Developers. +# Distributed under the terms of the BSD 3-Clause License. +# SPDX-License-Identifier: BSD-3-Clause +"""Solver to automatically calculate derived parameters from a dataset.""" + +from collections import ChainMap, deque +import contextlib +from inspect import signature, Parameter + + +class Path: + def __init__(self, steps, have, need): + self.steps = steps + self.have = have + self.need = need + + def is_complete(self): + return not bool(self.need) + + def __add__(self, other): + if any(f in set(self.steps) for f in other.steps): + raise ValueError(f'{other.steps} already in steps') + + # Prepend steps so that final path is in proper call order + # Don't really "have" what's in the new function call, but instead it just needs + # to be removed from what's needed. + return Path(other.steps + self.steps, self.have, + (self.need | other.need) - (self.have | other.have)) + + def __str__(self): + return (f'Path') + + __repr__ = __str__ + + +class Solver: + names = {'tw': 'wet_bulb_temperature', 'td': 'dewpoint_temperature', + 'dewpoint': 'dewpoint_temperature', 'tv': 'virtual_temperature', + 'q': 'specific_humidity', 'r': 'mixing_ratio', 'rh': 'relative_humidity', + 'p': 'pressure', 't': 'temperature', 'isobaric': 'pressure'} + + standard_names = {'temperature': 'air_temperature'} + + fallback_names = {'temperature': ['temp'], 'pressure': ['P', 'isobaric']} + + def __init__(self): + self._graph = {} + self._funcs = {} + + def register(self, *args, inputs=None): + def dec(func): + nonlocal inputs + nonlocal args + if inputs is None: + funcsig = signature(func) + inputs = [name for name, param in funcsig.parameters.items() if + param.default is Parameter.empty] + + if not args: + args = (func.__name__,) + + normed_returns = self.normalize_names(args) + normed_inputs = self.normalize_names(inputs) + path = Path([func], set(normed_returns), set(normed_inputs)) + self._funcs[func] = (normed_inputs, normed_returns) + for ret in normed_returns: + self._graph.setdefault(ret, []).append(path) + return func + + return dec + + def normalize_names(self, names): + return [self.normalize(name) for name in names] + + def normalize(self, name): + return self.names.get(name.lower(), name.lower()) + + def _map_func_args(self, func, data): + key_map = {self.normalize(key): key for key in ChainMap(data, data.coords)} + for item in self._funcs[func][0]: + if item in key_map: + yield data[key_map[item]] + elif item in self.standard_names: + ds = data.filter_by_attrs(standard_name=self.standard_names[item]) + yield next(iter(ds)) + else: + for name in self.fallback_names.get(item, []): + if name in key_map: + yield data[key_map[name]] + + def calculate(self, data, name): + data = data.copy() + for func in self.solve(set(data) | set(data.coords), name): + result = func(*self._map_func_args(func, data)) + retname = self._funcs[func][-1] + if isinstance(result, tuple): + for name, val in zip(retname, result): + data[name] = val + else: + data[retname] = result + + return data[self.normalize(name)] + + def solve(self, have, want): + # Using deque as a FIFO queue by pushing at one end and popping + # from the other--this makes this a Breadth-First Search + options = deque([Path([], set(self.normalize_names(have)), {self.normalize(want)})]) + while options: + path = options.popleft() + # If calculation path is complete, return the steps + if path.is_complete(): + return path.steps + else: + # Otherwise grab one of the remaining needs and + # add all methods for calculating to the current steps + # and make them options to consider + item = path.need.pop() + for trial_step in self._graph.get(item, ()): + # ValueError gets thrown if we try to repeat a function call + with contextlib.suppress(ValueError): + options.append(path + trial_step) + + raise ValueError(f'Unable to calculate {want} from {have}') + + +solver = Solver() diff --git a/src/metpy/calc/thermo.py b/src/metpy/calc/thermo.py index 35053e8ef68..8d658008d94 100644 --- a/src/metpy/calc/thermo.py +++ b/src/metpy/calc/thermo.py @@ -10,6 +10,7 @@ import scipy.optimize as so import xarray as xr +from .field_solver import solver from .exceptions import InvalidSoundingError from .tools import (_greater_or_close, _less_or_close, _remove_nans, find_bounding_indices, find_intersections, first_derivative, get_layer) @@ -26,6 +27,7 @@ @exporter.export +@solver.register('relative_humidity') @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'dewpoint')) @check_units('[temperature]', '[temperature]') def relative_humidity_from_dewpoint(temperature, dewpoint): @@ -101,6 +103,7 @@ def exner_function(pressure, reference_pressure=mpconsts.P0): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='temperature', broadcast=('pressure', 'temperature')) @check_units('[pressure]', '[temperature]') def potential_temperature(pressure, temperature): @@ -148,6 +151,7 @@ def potential_temperature(pressure, temperature): broadcast=('pressure', 'potential_temperature') ) @check_units('[pressure]', '[temperature]') +@solver.register('temperature') def temperature_from_potential_temperature(pressure, potential_temperature): r"""Calculate the temperature from a given potential temperature. @@ -1011,6 +1015,7 @@ def saturation_vapor_pressure(temperature): @exporter.export +@solver.register('dewpoint') @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'relative_humidity')) @check_units('[temperature]', '[dimensionless]') def dewpoint_from_relative_humidity(temperature, relative_humidity): @@ -1163,6 +1168,7 @@ def saturation_mixing_ratio(total_press, temperature): @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'dewpoint') @@ -1290,6 +1296,7 @@ def saturation_equivalent_potential_temperature(pressure, temperature): @exporter.export +@solver.register() @preprocess_and_wrap(wrap_like='temperature', broadcast=('temperature', 'mixing_ratio')) @check_units('[temperature]', '[dimensionless]', '[dimensionless]') def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=mpconsts.epsilon): @@ -1329,6 +1336,7 @@ def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=mpcons @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'mixing_ratio') @@ -1375,6 +1383,7 @@ def virtual_potential_temperature(pressure, temperature, mixing_ratio, @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'mixing_ratio') @@ -1533,6 +1542,7 @@ def psychrometric_vapor_pressure_wet(pressure, dry_bulb_temperature, wet_bulb_te @exporter.export +@solver.register('mixing_ratio') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'relative_humidity') @@ -1581,6 +1591,7 @@ def mixing_ratio_from_relative_humidity(pressure, temperature, relative_humidity @exporter.export +@solver.register('relative_humidity') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'mixing_ratio') @@ -1627,6 +1638,7 @@ def relative_humidity_from_mixing_ratio(pressure, temperature, mixing_ratio): @exporter.export +@solver.register('mixing_ratio') @preprocess_and_wrap(wrap_like='specific_humidity') @check_units('[dimensionless]') def mixing_ratio_from_specific_humidity(specific_humidity): @@ -1662,6 +1674,7 @@ def mixing_ratio_from_specific_humidity(specific_humidity): @exporter.export +@solver.register('specific_humidity') @preprocess_and_wrap(wrap_like='mixing_ratio') @check_units('[dimensionless]') def specific_humidity_from_mixing_ratio(mixing_ratio): @@ -1697,6 +1710,7 @@ def specific_humidity_from_mixing_ratio(mixing_ratio): @exporter.export +@solver.register('relative_humidity') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'specific_humidity') @@ -2946,6 +2960,7 @@ def brunt_vaisala_period(height, potential_temperature, vertical_dim=0): @exporter.export +@solver.register() @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'dewpoint') @@ -3051,6 +3066,7 @@ def static_stability(pressure, temperature, vertical_dim=0): @exporter.export +@solver.register('dewpoint') @preprocess_and_wrap( wrap_like='temperature', broadcast=('pressure', 'temperature', 'specific_humdiity') @@ -3188,6 +3204,7 @@ def vertical_velocity(omega, pressure, temperature, mixing_ratio=0): @exporter.export +@solver.register('specific_humidity') @preprocess_and_wrap(wrap_like='dewpoint', broadcast=('dewpoint', 'pressure')) @check_units('[pressure]', '[temperature]') def specific_humidity_from_dewpoint(pressure, dewpoint): diff --git a/src/metpy/plots/declarative.py b/src/metpy/plots/declarative.py index ac7cc01cdb2..c2565068d71 100644 --- a/src/metpy/plots/declarative.py +++ b/src/metpy/plots/declarative.py @@ -20,7 +20,7 @@ from ._mpl import TextCollection from .cartopy_utils import import_cartopy from .station_plot import StationPlot -from ..calc import reduce_point_density +from ..calc import reduce_point_density, solver from ..package_tools import Exporter from ..units import units @@ -1083,7 +1083,10 @@ def griddata(self): # Select our particular field of interest if self.field: - data = self.data.metpy.parse_cf(self.field) + if self.field in self.data: + data = self.data.metpy.parse_cf(self.field) + else: + data = solver.calculate(self.data.metpy.parse_cf(), self.field) elif hasattr(self.data.metpy, 'parse_cf'): # Handles the case where we have a dataset but no specified field raise ValueError('field attribute has not been set.') diff --git a/tests/calc/test_basic.py b/tests/calc/test_basic.py index e6ae8878e9b..3a75aa54b28 100644 --- a/tests/calc/test_basic.py +++ b/tests/calc/test_basic.py @@ -281,6 +281,16 @@ def test_heat_index_kelvin(): assert_almost_equal(hi.to('degC'), 50.3406 * units.degC, 4) +def test_heat_index_xarray(): + """Test heat_index when working with fields from xarray.""" + temp = xr.DataArray(np.full((1, 4, 2, 3), 35.), attrs={'units': 'degC'}, + dims=('t', 'p', 'y', 'x')) + rh = xr.DataArray(np.full((4, 1, 2, 3), 0.7), dims = ('p', 't', 'y', 'x')) + + hi = heat_index(temp, rh) + assert_almost_equal(hi, units.Quantity(50.3405, 'degC'), 4) + + def test_height_to_geopotential(array_type): """Test conversion from height to geopotential.""" mask = [False, True, False, True] diff --git a/tests/calc/test_solver.py b/tests/calc/test_solver.py new file mode 100644 index 00000000000..f0b88bccea6 --- /dev/null +++ b/tests/calc/test_solver.py @@ -0,0 +1,118 @@ +# Copyright (c) 2021 MetPy Developers. +# Distributed under the terms of the BSD 3-Clause License. +# SPDX-License-Identifier: BSD-3-Clause +"""Test the solver.""" +import pytest +import xarray as xr + +from metpy.calc import solver +from metpy.calc.field_solver import Solver +from metpy.testing import assert_almost_equal +from metpy.units import units + +test_solver = Solver() + +@test_solver.register('Td') +def dewpoint_from_relative_humidity(temperature, relative_humidity): + pass + + +@test_solver.register('RH') +def relative_humidity_from_dewpoint(temperature, dewpoint): + pass + + +@test_solver.register('RH') +def relative_humidity_from_specific_humidity(pressure, temperature, specific_humidity): + pass + + +@test_solver.register('r') +def mixing_ratio_from_relative_humidity(pressure, temperature, relative_humidity): + pass + + +@test_solver.register('Tv') +def virtual_temperature(temperature, mixing_ratio, molecular_weight_ratio=5): + pass + + +@test_solver.register('RH') +def relative_humidity_from_mixing_ratio(pressure, temperature, mixing_ratio): + pass + + +@test_solver.register('r') +def mixing_ratio_from_specific_humidity(specific_humidity): + pass + + +@test_solver.register('q') +def specific_humidity_from_mixing_ratio(mixing_ratio): + pass + + +@test_solver.register('rho') +def density(pressure, temperature, mixing_ratio): + pass + + +@test_solver.register('Tw') +def wet_bulb_temperature(pressure, temperature, dewpoint): + pass + + +@test_solver.register('u', 'v') +def wind_components(wind_speed, wind_direction): + pass + + +@test_solver.register() +def vorticity(u, v): + pass + + +@test_solver.register() +def heat_index(temperature, relative_humidity): + pass + + +@pytest.mark.parametrize(['inputs', 'want', 'truth'], [ + ({'T', 'RH'}, 'Td', [dewpoint_from_relative_humidity]), + ({'T', 'p', 'q'}, 'Td', [relative_humidity_from_specific_humidity, + dewpoint_from_relative_humidity]), + ({'T', 'p', 'q'}, 'Tv', [mixing_ratio_from_specific_humidity, virtual_temperature]), + ({'p', 'T', 'RH'}, 'rho', [mixing_ratio_from_relative_humidity, density]), + ({'p', 'T', 'RH'}, 'Tw', [dewpoint_from_relative_humidity, wet_bulb_temperature]), + ({'wind_speed', 'wind_direction'}, 'vorticity', [wind_components, vorticity]), + ({'T', 'p', 'q'}, 'heat_index', [relative_humidity_from_specific_humidity, heat_index]), +]) +def test_solutions(inputs, want, truth): + """Test that the proper sequence of calculations is found.""" + assert test_solver.solve(inputs, want) == truth + + +def test_failure(): + """Test that the correct error results when a value cannot be solved.""" + with pytest.raises(ValueError): + test_solver.solve({'RH'}, 'Td') + +def test_calculate(): + """Test using the solver results to calculate.""" + temp = xr.DataArray(25, attrs={'units': 'degC'}) + rh = xr.DataArray(80, attrs={'units': 'percent'}) + press = xr.DataArray(994, attrs={'units': 'hPa'}) + data = xr.Dataset({'temperature': temp, 'relative_humidity': rh, 'pressure': press}) + + dewp = solver.calculate(data, 'dewpoint') + assert_almost_equal(dewp, units.Quantity(21.3125, 'degC'), 4) + + +def test_xarray_mapping(): + temp = xr.DataArray([25], attrs={'units': 'degC'}) + q = xr.DataArray([.019], attrs={'units': 'percent'}) + press = xr.DataArray([994], attrs={'units': 'hPa'}) + data = xr.Dataset({'Temperature': temp, 'Specific_humidity': q}, + coords={'isobaric': press}) + + solver.calculate(data, 'heat_index')