From 0c5e289daa968f2887412743657f955de5959658 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Sat, 12 Aug 2023 10:13:03 -0700 Subject: [PATCH] proposed fix for #229 --- clouddrift/analysis.py | 52 ++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/clouddrift/analysis.py b/clouddrift/analysis.py index 438b31a5..9b2d2a5c 100644 --- a/clouddrift/analysis.py +++ b/clouddrift/analysis.py @@ -612,17 +612,15 @@ def position_from_velocity( f" {len(x.shape) - 1}])." ) - # Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1) - target_axes = list(range(len(u.shape))) - - # If time_axis is not the last one, transpose the inputs + # If time_axis is not the last one, swap axes if time_axis != -1 and time_axis < len(u.shape) - 1: - target_axes.append(target_axes.pop(target_axes.index(time_axis))) - - # Reshape the inputs to ensure the time axis is last (fast-varying) - u_ = np.transpose(u, target_axes) - v_ = np.transpose(v, target_axes) - time_ = np.transpose(time, target_axes) + u_ = np.swapaxes(u, time_axis < -1) + v_ = np.swapaxes(v, time_axis < -1) + time_ = np.swapaxes(time, time_axis < -1) + else: + u_ = u + v_ = v + time_ = time x = np.zeros(u_.shape, dtype=u.dtype) y = np.zeros(v_.shape, dtype=v.dtype) @@ -659,10 +657,11 @@ def position_from_velocity( else: raise ValueError('coord_system must be "spherical" or "cartesian".') - if target_axes == list(range(len(u.shape))): - return x, y + # this was tested before, should we save that test to reuse here? + if time_axis != -1 and time_axis != len(x.shape) - 1: + return np.swapaxes(x, time_axis, -1), np.swapaxes(x, time_axis, -1) else: - return np.transpose(x, target_axes), np.transpose(y, target_axes) + return x, y def velocity_from_position( @@ -754,17 +753,15 @@ def velocity_from_position( f" {len(x.shape) - 1}])." ) - # Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1) - target_axes = list(range(len(x.shape))) - - # If time_axis is not the last one, transpose the inputs - if time_axis != -1 and time_axis < len(x.shape) - 1: - target_axes.append(target_axes.pop(target_axes.index(time_axis))) - - # Reshape the inputs to ensure the time axis is last (fast-varying) - x_ = np.transpose(x, target_axes) - y_ = np.transpose(y, target_axes) - time_ = np.transpose(time, target_axes) + # If time_axis is not the last one, swap axes + if time_axis != -1 and time_axis < len(u.shape) - 1: + x_ = np.swapaxes(x, time_axis < -1) + y_ = np.swapaxes(y, time_axis < -1) + time_ = np.swapaxes(time, time_axis < -1) + else: + x_ = x + y_ = y + time_ = time dx = np.empty(x_.shape) dy = np.empty(y_.shape) @@ -873,10 +870,11 @@ def velocity_from_position( 'difference_scheme must be "forward", "backward", or "centered".' ) - if target_axes == list(range(len(x.shape))): - return dx / dt, dy / dt + # this was tested before, should we save that test to reuse here? + if time_axis != -1 and time_axis != len(x.shape) - 1: + return np.swapaxes(dx / dt, time_axis, -1), np.swapaxes(dy / dt, time_axis, -1) else: - return np.transpose(dx / dt, target_axes), np.transpose(dy / dt, target_axes) + return dx / dt, dy / dt def mask_var(