Skip to content

Commit

Permalink
proposed fix for #229
Browse files Browse the repository at this point in the history
  • Loading branch information
selipot committed Aug 12, 2023
1 parent dff9772 commit 0c5e289
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions clouddrift/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,17 +612,15 @@ def position_from_velocity(
f" {len(x.shape) - 1}])."
)

# Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1)
target_axes = list(range(len(u.shape)))

# If time_axis is not the last one, transpose the inputs
# If time_axis is not the last one, swap axes
if time_axis != -1 and time_axis < len(u.shape) - 1:
target_axes.append(target_axes.pop(target_axes.index(time_axis)))

# Reshape the inputs to ensure the time axis is last (fast-varying)
u_ = np.transpose(u, target_axes)
v_ = np.transpose(v, target_axes)
time_ = np.transpose(time, target_axes)
u_ = np.swapaxes(u, time_axis < -1)
v_ = np.swapaxes(v, time_axis < -1)
time_ = np.swapaxes(time, time_axis < -1)
else:
u_ = u
v_ = v
time_ = time

x = np.zeros(u_.shape, dtype=u.dtype)
y = np.zeros(v_.shape, dtype=v.dtype)
Expand Down Expand Up @@ -659,10 +657,11 @@ def position_from_velocity(
else:
raise ValueError('coord_system must be "spherical" or "cartesian".')

if target_axes == list(range(len(u.shape))):
return x, y
# this was tested before, should we save that test to reuse here?
if time_axis != -1 and time_axis != len(x.shape) - 1:
return np.swapaxes(x, time_axis, -1), np.swapaxes(x, time_axis, -1)
else:
return np.transpose(x, target_axes), np.transpose(y, target_axes)
return x, y


def velocity_from_position(
Expand Down Expand Up @@ -754,17 +753,15 @@ def velocity_from_position(
f" {len(x.shape) - 1}])."
)

# Nominal order of axes on input, i.e. (0, 1, 2, ..., N-1)
target_axes = list(range(len(x.shape)))

# If time_axis is not the last one, transpose the inputs
if time_axis != -1 and time_axis < len(x.shape) - 1:
target_axes.append(target_axes.pop(target_axes.index(time_axis)))

# Reshape the inputs to ensure the time axis is last (fast-varying)
x_ = np.transpose(x, target_axes)
y_ = np.transpose(y, target_axes)
time_ = np.transpose(time, target_axes)
# If time_axis is not the last one, swap axes
if time_axis != -1 and time_axis < len(u.shape) - 1:
x_ = np.swapaxes(x, time_axis < -1)
y_ = np.swapaxes(y, time_axis < -1)
time_ = np.swapaxes(time, time_axis < -1)
else:
x_ = x
y_ = y
time_ = time

dx = np.empty(x_.shape)
dy = np.empty(y_.shape)
Expand Down Expand Up @@ -873,10 +870,11 @@ def velocity_from_position(
'difference_scheme must be "forward", "backward", or "centered".'
)

if target_axes == list(range(len(x.shape))):
return dx / dt, dy / dt
# this was tested before, should we save that test to reuse here?
if time_axis != -1 and time_axis != len(x.shape) - 1:
return np.swapaxes(dx / dt, time_axis, -1), np.swapaxes(dy / dt, time_axis, -1)
else:
return np.transpose(dx / dt, target_axes), np.transpose(dy / dt, target_axes)
return dx / dt, dy / dt


def mask_var(
Expand Down

0 comments on commit 0c5e289

Please sign in to comment.