From 0c5e289daa968f2887412743657f955de5959658 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Sat, 12 Aug 2023 10:13:03 -0700 Subject: [PATCH 1/3] proposed fix for #229 --- clouddrift/analysis.py | 52 ++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/clouddrift/analysis.py b/clouddrift/analysis.py index 438b31a5..9b2d2a5c 100644 --- a/clouddrift/analysis.py +++ b/clouddrift/analysis.py @@ -612,17 +612,15 @@ def position_from_velocity( f" {len(x.shape) - 1}])." ) - # Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1) - target_axes = list(range(len(u.shape))) - - # If time_axis is not the last one, transpose the inputs + # If time_axis is not the last one, swap axes if time_axis != -1 and time_axis < len(u.shape) - 1: - target_axes.append(target_axes.pop(target_axes.index(time_axis))) - - # Reshape the inputs to ensure the time axis is last (fast-varying) - u_ = np.transpose(u, target_axes) - v_ = np.transpose(v, target_axes) - time_ = np.transpose(time, target_axes) + u_ = np.swapaxes(u, time_axis < -1) + v_ = np.swapaxes(v, time_axis < -1) + time_ = np.swapaxes(time, time_axis < -1) + else: + u_ = u + v_ = v + time_ = time x = np.zeros(u_.shape, dtype=u.dtype) y = np.zeros(v_.shape, dtype=v.dtype) @@ -659,10 +657,11 @@ def position_from_velocity( else: raise ValueError('coord_system must be "spherical" or "cartesian".') - if target_axes == list(range(len(u.shape))): - return x, y + # this was tested before, should we save that test to reuse here? + if time_axis != -1 and time_axis != len(x.shape) - 1: + return np.swapaxes(x, time_axis, -1), np.swapaxes(x, time_axis, -1) else: - return np.transpose(x, target_axes), np.transpose(y, target_axes) + return x, y def velocity_from_position( @@ -754,17 +753,15 @@ def velocity_from_position( f" {len(x.shape) - 1}])." ) - # Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1) - target_axes = list(range(len(x.shape))) - - # If time_axis is not the last one, transpose the inputs - if time_axis != -1 and time_axis < len(x.shape) - 1: - target_axes.append(target_axes.pop(target_axes.index(time_axis))) - - # Reshape the inputs to ensure the time axis is last (fast-varying) - x_ = np.transpose(x, target_axes) - y_ = np.transpose(y, target_axes) - time_ = np.transpose(time, target_axes) + # If time_axis is not the last one, swap axes + if time_axis != -1 and time_axis < len(u.shape) - 1: + x_ = np.swapaxes(x, time_axis < -1) + y_ = np.swapaxes(y, time_axis < -1) + time_ = np.swapaxes(time, time_axis < -1) + else: + x_ = x + y_ = y + time_ = time dx = np.empty(x_.shape) dy = np.empty(y_.shape) @@ -873,10 +870,11 @@ def velocity_from_position( 'difference_scheme must be "forward", "backward", or "centered".' ) - if target_axes == list(range(len(x.shape))): - return dx / dt, dy / dt + # this was tested before, should we save that test to reuse here? + if time_axis != -1 and time_axis != len(x.shape) - 1: + return np.swapaxes(dx / dt, time_axis, -1), np.swapaxes(dy / dt, time_axis, -1) else: - return np.transpose(dx / dt, target_axes), np.transpose(dy / dt, target_axes) + return dx / dt, dy / dt def mask_var( From 0d78e44ac9bc5a4c31e2dfc261a226e4b976eb0b Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Sat, 12 Aug 2023 10:22:11 -0700 Subject: [PATCH 2/3] fix --- clouddrift/analysis.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clouddrift/analysis.py b/clouddrift/analysis.py index 9b2d2a5c..6bf71159 100644 --- a/clouddrift/analysis.py +++ b/clouddrift/analysis.py @@ -658,7 +658,7 @@ def position_from_velocity( raise ValueError('coord_system must be "spherical" or "cartesian".') # this was tested before, should we save that test to reuse here? - if time_axis != -1 and time_axis != len(x.shape) - 1: + if time_axis != -1 and time_axis != len(u.shape) - 1: return np.swapaxes(x, time_axis, -1), np.swapaxes(x, time_axis, -1) else: return x, y @@ -754,7 +754,7 @@ def velocity_from_position( ) # If time_axis is not the last one, swap axes - if time_axis != -1 and time_axis < len(u.shape) - 1: + if time_axis != -1 and time_axis < len(x.shape) - 1: x_ = np.swapaxes(x, time_axis < -1) y_ = np.swapaxes(y, time_axis < -1) time_ = np.swapaxes(time, time_axis < -1) From ff6ba8024d2b8b133c6eba6ea7f94fec959d6d24 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Sat, 12 Aug 2023 10:35:22 -0700 Subject: [PATCH 3/3] typo fix --- clouddrift/analysis.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/clouddrift/analysis.py b/clouddrift/analysis.py index 6bf71159..07f98498 100644 --- a/clouddrift/analysis.py +++ b/clouddrift/analysis.py @@ -614,9 +614,9 @@ def position_from_velocity( # If time_axis is not the last one, swap axes if time_axis != -1 and time_axis < len(u.shape) - 1: - u_ = np.swapaxes(u, time_axis < -1) - v_ = np.swapaxes(v, time_axis < -1) - time_ = np.swapaxes(time, time_axis < -1) + u_ = np.swapaxes(u, time_axis , -1) + v_ = np.swapaxes(v, time_axis , -1) + time_ = np.swapaxes(time, time_axis , -1) else: u_ = u v_ = v @@ -755,9 +755,9 @@ def velocity_from_position( # If time_axis is not the last one, swap axes if time_axis != -1 and time_axis < len(x.shape) - 1: - x_ = np.swapaxes(x, time_axis < -1) - y_ = np.swapaxes(y, time_axis < -1) - time_ = np.swapaxes(time, time_axis < -1) + x_ = np.swapaxes(x, time_axis, -1) + y_ = np.swapaxes(y, time_axis, -1) + time_ = np.swapaxes(time, time_axis, -1) else: x_ = x y_ = y