diff --git a/oggm/tests/test_utils.py b/oggm/tests/test_utils.py index 2d8094bd3..b0ffb37bd 100644 --- a/oggm/tests/test_utils.py +++ b/oggm/tests/test_utils.py @@ -123,6 +123,16 @@ def test_floatyear_to_date(self): r = utils.floatyear_to_date(yr) assert r == (1998, 2) + # tests for floating point precision + yr = 1 + 1/12 - 1/12 + r = utils.floatyear_to_date(yr) + assert r == (1, 1) + + for i in range(12): + yr = 2000 + r = utils.floatyear_to_date(yr + i / 12) + assert r == (yr, i + 1) + def test_date_to_floatyear(self): r = utils.date_to_floatyear(0, 1) diff --git a/oggm/utils/_funcs.py b/oggm/utils/_funcs.py index 247bccf14..79efd2d27 100644 --- a/oggm/utils/_funcs.py +++ b/oggm/utils/_funcs.py @@ -690,6 +690,19 @@ def floatyear_to_date(yr): if isinstance(yr, xr.DataArray): yr = yr.values + # Ensure yr is a np.array, even for scalar values + yr = np.atleast_1d(yr).astype(np.float64) + + # check if year is inside machine precision to next higher int + yr_ceil = np.ceil(yr) + yr = np.where(np.isclose(yr, + yr_ceil, + rtol=np.finfo(np.float64).eps, + atol=0 + ), + yr_ceil, + yr) + out_y, remainder = np.divmod(yr, 1) out_y = out_y.astype(int) @@ -700,7 +713,7 @@ def floatyear_to_date(yr): np.round(month_exact), np.floor(month_exact)).astype(int)) - if (isinstance(yr, list) or isinstance(yr, np.ndarray)) and len(yr) == 1: + if yr.size == 1: out_y = out_y.item() out_m = out_m.item()