Skip to content

Commit

Permalink
fix typing issues from checking untyped defs, fixes #509
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsantana11 committed Aug 15, 2024
1 parent 743d71b commit 21ec3ee
Show file tree
Hide file tree
Showing 16 changed files with 133 additions and 97 deletions.
13 changes: 7 additions & 6 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Callable, Iterable
from concurrent import futures
from datetime import timedelta
from typing import Any

import numpy as np
import pandas as pd
Expand All @@ -16,12 +17,12 @@ def apply_ragged(
func: callable,
arrays: list[np.ndarray | xr.DataArray] | np.ndarray | xr.DataArray,
rowsize: list[int] | np.ndarray[int] | xr.DataArray,
*args: tuple,
*args: Any,
rows: int | Iterable[int] = None,
axis: int = 0,
executor: futures.Executor = futures.ThreadPoolExecutor(max_workers=None),
**kwargs: dict,
) -> tuple[np.ndarray] | np.ndarray:
**kwargs: Any,
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
"""Apply a function to a ragged array.
The function ``func`` will be applied to each contiguous row of ``arrays`` as
Expand Down Expand Up @@ -450,9 +451,9 @@ def rowsize_to_index(rowsize: list | np.ndarray | xr.DataArray) -> np.ndarray:


def segment(
x: np.ndarray,
x: list | np.ndarray | xr.DataArray | pd.Series,
tolerance: float | np.timedelta64 | timedelta | pd.Timedelta,
rowsize: np.ndarray[int] = None,
rowsize: np.ndarray[int] | None = None,
) -> np.ndarray[int]:
"""Divide an array into segments based on a tolerance value.
Expand Down Expand Up @@ -789,7 +790,7 @@ def subset(
def unpack(
ragged_array: np.ndarray,
rowsize: np.ndarray[int],
rows: int | Iterable[int] = None,
rows: int | np.int_ | Iterable[int] | None = None,
axis: int = 0,
) -> list[np.ndarray]:
"""Unpack a ragged array into a list of regular arrays.
Expand Down
27 changes: 15 additions & 12 deletions clouddrift/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import warnings
from typing import TypeVar

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -207,7 +208,10 @@ def bearing(


def position_from_distance_and_bearing(
lon: float, lat: float, distance: float, bearing: float
lon: float | np.ndarray,
lat: float | np.ndarray,
distance: float | np.ndarray,
bearing: float | np.ndarray,
) -> tuple[float, float]:
"""Return elementwise new position in degrees from arrays of latitude and
longitude in degrees, distance in meters, and bearing in radians, based on
Expand Down Expand Up @@ -660,13 +664,12 @@ def cartesian_to_spherical(
return lon, lat


T = TypeVar("T", bound=float | np.ndarray)


def cartesian_to_tangentplane(
u: float | np.ndarray,
v: float | np.ndarray,
w: float | np.ndarray,
longitude: float | np.ndarray,
latitude: float | np.ndarray,
) -> tuple[float, float] | tuple[np.ndarray, np.ndarray]:
u: T, v: T, w: T, longitude: T, latitude: T
) -> tuple[T, T]:
"""
Project a three-dimensional Cartesian vector on a plane tangent to
a spherical Earth.
Expand Down Expand Up @@ -725,12 +728,12 @@ def cartesian_to_tangentplane(
return u_projected, v_projected


T = TypeVar("T", bound=float | np.ndarray)


def tangentplane_to_cartesian(
up: float | np.ndarray,
vp: float | np.ndarray,
longitude: float | np.ndarray,
latitude: float | np.ndarray,
) -> tuple[float, float, float] | tuple[np.ndarray, np.ndarray, np.ndarray]:
up: T, vp: T, longitude: T, latitude: T
) -> tuple[T, T, T]:
"""
Return the three-dimensional Cartesian components of a vector contained in
a plane tangent to a spherical Earth.
Expand Down
13 changes: 13 additions & 0 deletions clouddrift/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from datetime import timedelta

import numpy as np
import pandas as pd
import xarray as xr

_SupportedArrayTypes = list | np.ndarray | xr.DataArray | pd.Series
_ArrayTypes = _SupportedArrayTypes

_SupportedTimeDeltaTypes = pd.Timedelta | timedelta | np.timedelta64
_TimeDeltaTypes = _SupportedTimeDeltaTypes

__all__ = ["_ArrayTypes", "_TimeDeltaTypes"]
6 changes: 3 additions & 3 deletions clouddrift/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ def wavelet_transform(

def morse_wavelet(
length: int,
gamma: float,
beta: float,
radian_frequency: np.ndarray,
gamma: float | np.ndarray,
beta: float | np.ndarray,
radian_frequency: float | np.ndarray,
order: int = 1,
normalization: str = "bandpass",
) -> tuple[np.ndarray, np.ndarray]:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ select = ["E4", "E7", "E9", "F", "I"]
[tool.mypy]
python_version = "3.10"
follow_imports = "normal"
check_untyped_defs = true
files = [
"clouddrift/**/*.py",
"tests/**/*.py",
Expand Down
4 changes: 1 addition & 3 deletions tests/adapters/gdp1h_integ_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,5 @@ def test_load_subset_and_create_aggregate(self):

@classmethod
def tearDownClass(cls):
[
for dir in [gdp1h.GDP_TMP_PATH, gdp1h.GDP_TMP_PATH_EXPERIMENTAL]:
shutil.rmtree(dir)
for dir in [gdp1h.GDP_TMP_PATH, gdp1h.GDP_TMP_PATH_EXPERIMENTAL]
]
2 changes: 1 addition & 1 deletion tests/adapters/gdp6h_integ_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def test_load_subset_and_create_aggregate(self):

@classmethod
def tearDownClass(cls):
[shutil.rmtree(dir) for dir in [gdp6h.GDP_TMP_PATH]]
shutil.rmtree(gdp6h.GDP_TMP_PATH)
2 changes: 1 addition & 1 deletion tests/adapters/hurdat2_integ_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ def test_conversion(self):

@classmethod
def tearDownClass(cls):
[shutil.rmtree(dir) for dir in [hurdat2._DEFAULT_FILE_PATH]]
shutil.rmtree(hurdat2._DEFAULT_FILE_PATH)
3 changes: 2 additions & 1 deletion tests/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ def __enter__(self) -> Sequence[Mock]:
return [p.start() for p in self._patches]

def __exit__(self, *_):
[p.stop() for p in self._patches]
for p in self._patches:
p.stop()
12 changes: 5 additions & 7 deletions tests/adapters/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,7 @@ def test_forgo_download_no_update(self):
),
]
) as _:
utils._download_with_progress(
"some.url.com", "./some/path/existing-file.nc", 0, False
)
utils._download_with_progress("some.url.com", Mock(), 0, False)
self.requests_mock.get.assert_not_called()

def test_download_new_update(self):
Expand Down Expand Up @@ -92,9 +90,9 @@ def test_progress_mechanism_disabled_files(self):
"""
mocked_futures = [self.gen_future_mock() for _ in range(0, 3)]
download_requests = [
("src0", "dst", None),
("src1", "dst", None),
("src2", "dst", None),
("src0", "dst"),
("src1", "dst"),
("src2", "dst"),
]

tpe_mock = Mock()
Expand Down Expand Up @@ -127,7 +125,7 @@ def test_progress_mechanism_enabled_files(self):
"""

mocked_futures = [self.gen_future_mock() for _ in range(0, 21)]
download_requests = [("src0", "dst", None) for _ in range(0, 21)]
download_requests = [("src0", "dst") for _ in range(0, 21)]

tpe_mock = Mock()
tpe_mock.__enter__ = Mock(return_value=tpe_mock)
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_glad_subset_and_apply_ragged_work(self):
)
self.assertTrue(ds_sub)
mean_lon = apply_ragged(np.mean, [ds_sub.longitude], ds_sub.rowsize)
self.assertTrue(mean_lon.size == 2)
self.assertTrue(len(mean_lon) == 2)

def test_spotters_opens(self):
with datasets.spotters() as ds:
Expand Down
82 changes: 49 additions & 33 deletions tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
unpack,
)
from clouddrift.raggedarray import RaggedArray
from clouddrift.typing import _ArrayTypes, _TimeDeltaTypes

if __name__ == "__main__":
unittest.main()
Expand Down Expand Up @@ -211,7 +212,9 @@ def test_prune(self):
rowsize = [3, 2, 4]
minimum = 3

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
for data in list[_ArrayTypes](
[x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
):
x_new, rowsize_new = prune(data, rowsize, minimum)
self.assertTrue(isinstance(x_new, np.ndarray))
self.assertTrue(isinstance(rowsize_new, np.ndarray))
Expand All @@ -223,7 +226,10 @@ def test_prune_all_longer(self):
rowsize = [3, 2, 4]
minimum = 1

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
data: _ArrayTypes
for data in list[_ArrayTypes](
[x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
):
x_new, rowsize_new = prune(data, rowsize, minimum)
np.testing.assert_equal(x_new, data)
np.testing.assert_equal(rowsize_new, rowsize)
Expand All @@ -233,7 +239,10 @@ def test_prune_all_smaller(self):
rowsize = [3, 2, 4]
minimum = 5

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
data: _ArrayTypes
for data in list[_ArrayTypes](
[x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
):
x_new, rowsize_new = prune(data, rowsize, minimum)
np.testing.assert_equal(x_new, np.array([]))
np.testing.assert_equal(rowsize_new, np.array([]))
Expand Down Expand Up @@ -266,17 +275,23 @@ def test_prune_keep_nan(self):
rowsize = [3, 2, 4]
minimum = 3

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
data: _ArrayTypes
for data in list[_ArrayTypes](
[x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
):
x_new, rowsize_new = prune(data, rowsize, minimum)
np.testing.assert_equal(x_new, [1, 2, np.nan, 1, 2, np.nan, 4])
np.testing.assert_equal(rowsize_new, [3, 4])

def test_prune_empty(self):
x = []
rowsize = []
x: list[int] = []
rowsize: list[int] = []
minimum = 3

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
data: _ArrayTypes
for data in list[_ArrayTypes](
[x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
):
with self.assertRaises(IndexError):
x_new, rowsize_new = prune(data, rowsize, minimum)

Expand All @@ -285,7 +300,10 @@ def test_print_incompatible_rowsize(self):
rowsize = [3, 3]
minimum = 3

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
data: _ArrayTypes
for data in list[_ArrayTypes](
[x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]
):
with self.assertRaises(ValueError):
x_new, rowsize_new = prune(data, rowsize, minimum)

Expand All @@ -304,9 +322,7 @@ def test_segment(self):
def test_segment_zero_tolerance(self):
x = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
tol = 0
self.assertIsNone(
np.testing.assert_equal(segment(x, tol), np.array([1, 2, 3, 4]))
)
np.testing.assert_equal(segment(x, tol), np.array([1, 2, 3, 4]))

def test_segment_negative_tolerance(self):
x = [0, 1, 1, 1, 2, 0, 3, 3, 3, 4]
Expand All @@ -316,7 +332,7 @@ def test_segment_negative_tolerance(self):
def test_segment_rowsize(self):
x = [0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4]
tol = 0.5
rowsize = [6, 5]
rowsize = np.array([6, 5])
segment_sizes = segment(x, tol, rowsize)
self.assertTrue(isinstance(segment_sizes, np.ndarray))
self.assertTrue(np.all(segment_sizes == np.array([1, 3, 2, 4, 1])))
Expand All @@ -327,9 +343,9 @@ def test_segment_positive_and_negative_tolerance(self):
self.assertTrue(np.all(segment_sizes == np.array([2, 2, 2, 2])))

def test_segment_rowsize_raises(self):
x = [0, 1, 2, 3]
x = np.array([0, 1, 2, 3])
tol = 0.5
rowsize = [1, 2] # rowsize is too short
rowsize = np.array([1, 2]) # rowsize is too short
with self.assertRaises(ValueError):
segment(x, tol, rowsize)

Expand All @@ -341,10 +357,11 @@ def test_segments_datetime(self):
datetime(2023, 2, 1),
datetime(2023, 2, 2),
]
for tol in [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]:
self.assertIsNone(
np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
)
tol: _TimeDeltaTypes
for tol in list[_TimeDeltaTypes](
[pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]
):
np.testing.assert_equal(segment(x, tol), np.array([3, 2]))

def test_segments_numpy(self):
x = np.array(
Expand All @@ -356,17 +373,19 @@ def test_segments_numpy(self):
np.datetime64("2023-02-02"),
]
)
for tol in [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]:
self.assertIsNone(
np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
)
for tol in list[_TimeDeltaTypes](
[pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]
):
np.testing.assert_equal(segment(x, tol), np.array([3, 2]))

def test_segments_pandas(self):
x = pd.to_datetime(["1/1/2023", "1/2/2023", "1/3/2023", "2/1/2023", "2/2/2023"])
for tol in [pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]:
self.assertIsNone(
np.testing.assert_equal(segment(x, tol), np.array([3, 2]))
)
x: pd.Series = pd.to_datetime(
pd.Series(["1/1/2023", "1/2/2023", "1/3/2023", "2/1/2023", "2/2/2023"])
)
for tol in list[_TimeDeltaTypes](
[pd.Timedelta("1 day"), timedelta(days=1), np.timedelta64(1, "D")]
):
np.testing.assert_equal(segment(x, tol), np.array([3, 2]))


class ragged_to_regular_tests(unittest.TestCase):
Expand Down Expand Up @@ -503,8 +522,8 @@ def test_with_axis(self):
# ragged axis 1 is th same as applying it to the transpose over ragged
# axis 0.
rowsize = [1, 1]
y0 = apply_ragged(func, x.T, rowsize, axis=0)
y1 = apply_ragged(func, x, rowsize, axis=1)
y0, _ = apply_ragged(func, x.T, rowsize, axis=0)
y1, _ = apply_ragged(func, x, rowsize, axis=1)
self.assertTrue(np.all(y0 == y1.T))

# Test that axis=1 works with reduction over the non-ragged axis.
Expand Down Expand Up @@ -570,10 +589,7 @@ def test_bad_rowsize_raises(self):
with self.assertRaises(ValueError):
for use_threads in [True, False]:
apply_ragged(
lambda x: x**2,
np.array([1, 2, 3, 4]),
[2],
use_threads=use_threads,
lambda x: x**2, np.array([1, 2, 3, 4]), [2], use_threads=use_threads
)


Expand Down
Loading

0 comments on commit 21ec3ee

Please sign in to comment.