Skip to content

Commit

Permalink
replace type comparison with isinstance (#491)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippemiron authored Jul 16, 2024
1 parent dddc803 commit 079ae36
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
6 changes: 3 additions & 3 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def apply_ragged(
If empty ``arrays``.
"""
# make sure the arrays is iterable
if type(arrays) not in [list, tuple]:
if not isinstance(arrays, (list, tuple)):
arrays = [arrays]
# validate rowsize
for arr in arrays:
Expand Down Expand Up @@ -520,10 +520,10 @@ def segment(
"""

# for compatibility with datetime list or np.timedelta64 arrays
if type(tolerance) in [np.timedelta64, timedelta]:
if isinstance(tolerance, (np.timedelta64, timedelta)):
tolerance = pd.Timedelta(tolerance)

if type(tolerance) == pd.Timedelta:
if isinstance(tolerance, pd.Timedelta):
positive_tol = tolerance >= pd.Timedelta("0 seconds")
else:
positive_tol = tolerance >= 0
Expand Down
16 changes: 8 additions & 8 deletions clouddrift/sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,19 @@ def distance(
# Input coordinates are in degrees; convert to radians.
# If any of the input arrays are xr.DataArray, extract the values first
# because Xarray enforces alignment between coordinates.
if type(lat1) is xr.DataArray:
if isinstance(lat1, xr.DataArray):
lat1_rad = np.deg2rad(lat1.values)
else:
lat1_rad = np.deg2rad(lat1)
if type(lon1) is xr.DataArray:
if isinstance(lon1, xr.DataArray):
lon1_rad = np.deg2rad(lon1.values)
else:
lon1_rad = np.deg2rad(lon1)
if type(lat2) is xr.DataArray:
if isinstance(lat2, xr.DataArray):
lat2_rad = np.deg2rad(lat2.values)
else:
lat2_rad = np.deg2rad(lat2)
if type(lon2) is xr.DataArray:
if isinstance(lon2, xr.DataArray):
lon2_rad = np.deg2rad(lon2.values)
else:
lon2_rad = np.deg2rad(lon2)
Expand Down Expand Up @@ -178,19 +178,19 @@ def bearing(
# Input coordinates are in degrees; convert to radians.
# If any of the input arrays are xr.DataArray, extract the values first
# because Xarray enforces alignment between coordinates.
if type(lat1) is xr.DataArray:
if isinstance(lat1, xr.DataArray):
lat1_rad = np.deg2rad(lat1.values)
else:
lat1_rad = np.deg2rad(lat1)
if type(lon1) is xr.DataArray:
if isinstance(lon1, xr.DataArray):
lon1_rad = np.deg2rad(lon1.values)
else:
lon1_rad = np.deg2rad(lon1)
if type(lat2) is xr.DataArray:
if isinstance(lat2, xr.DataArray):
lat2_rad = np.deg2rad(lat2.values)
else:
lat2_rad = np.deg2rad(lat2)
if type(lon2) is xr.DataArray:
if isinstance(lon2, xr.DataArray):
lon2_rad = np.deg2rad(lon2.values)
else:
lon2_rad = np.deg2rad(lon2)
Expand Down
12 changes: 6 additions & 6 deletions tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def test_prune(self):

for data in [x, np.array(x), pd.Series(data=x), xr.DataArray(data=x)]:
x_new, rowsize_new = prune(data, rowsize, minimum)
self.assertTrue(type(x_new) is np.ndarray)
self.assertTrue(type(rowsize_new) is np.ndarray)
self.assertTrue(isinstance(x_new, np.ndarray))
self.assertTrue(isinstance(rowsize_new, np.ndarray))
np.testing.assert_equal(x_new, [1, 2, 3, 1, 2, 3, 4])
np.testing.assert_equal(rowsize_new, [3, 4])

Expand Down Expand Up @@ -293,7 +293,7 @@ class segment_tests(unittest.TestCase):
def test_segment(self):
x = [0, 1, 1, 1, 2, 2, 3, 3, 3, 3, 4]
tol = 0.5
self.assertTrue(type(segment(x, tol)) is np.ndarray)
self.assertTrue(isinstance(segment(x, tol), np.ndarray))
self.assertTrue(np.all(segment(x, tol) == np.array([1, 3, 2, 4, 1])))
self.assertTrue(np.all(segment(np.array(x), tol) == np.array([1, 3, 2, 4, 1])))
self.assertTrue(
Expand All @@ -317,7 +317,7 @@ def test_segment_rowsize(self):
tol = 0.5
rowsize = [6, 5]
segment_sizes = segment(x, tol, rowsize)
self.assertTrue(type(segment_sizes) is np.ndarray)
self.assertTrue(isinstance(segment_sizes, np.ndarray))
self.assertTrue(np.all(segment_sizes == np.array([1, 3, 2, 4, 1])))

def test_segment_positive_and_negative_tolerance(self):
Expand Down Expand Up @@ -762,7 +762,7 @@ def test_unpack(self):
lon = unpack(ds.lon, ds["rowsize"])

self.assertTrue(isinstance(lon, list))
self.assertTrue(np.all([type(a) is xr.DataArray for a in lon]))
self.assertTrue(np.all([isinstance(a, xr.DataArray) for a in lon]))
self.assertTrue(
np.all([lon[n].size == ds["rowsize"][n] for n in range(len(lon))])
)
Expand All @@ -771,7 +771,7 @@ def test_unpack(self):
lon = unpack(ds.lon.values, ds["rowsize"])

self.assertTrue(isinstance(lon, list))
self.assertTrue(np.all([type(a) is np.ndarray for a in lon]))
self.assertTrue(np.all([isinstance(a, np.ndarray) for a in lon]))
self.assertTrue(
np.all([lon[n].size == ds["rowsize"][n] for n in range(len(lon))])
)
Expand Down
2 changes: 1 addition & 1 deletion tests/signal_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_size(self):
def test_imag(self):
x = np.random.rand(99) + 1j * np.random.rand(99)
z = analytic_signal(x)
self.assertTrue(type(z) is tuple)
self.assertTrue(isinstance(z, tuple))

def test_real_odd(self):
x = np.random.rand(99)
Expand Down

0 comments on commit 079ae36

Please sign in to comment.