Skip to content

Commit

Permalink
Merge pull request #90 from yupidevs/develop
Browse files Browse the repository at this point in the history
Fix trajectory save and load when dt is estimated
  • Loading branch information
jmorgadov authored Jan 29, 2022
2 parents 1fb9b1f + 54a4ea9 commit f1f9c78
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
author = 'Gustavo Viera-López, Alfredo Reyes, Jorge Morgado, Ernesto Altshuler'

# The full version, including alpha/beta/rc tags
release = '0.8.3'
release = '0.8.4'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "yupi"
version = "0.8.3"
version = "0.8.4"
description = "A package for tracking and analysing objects trajectories"
authors = [
"Gustavo Viera-López <[email protected]>",
Expand Down
15 changes: 11 additions & 4 deletions tests/test_trajectory/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@


def test_save():
t1 = Trajectory(x=[1,2,3], y=[4,5,6])
x = [1,2,3]
y = [4,5,6]
t1 = Trajectory(x=x, y=y)

# Wrong trajectory file extension at saving
with pytest.raises(ValueError):
t1.save('t1', file_type='abc')

# Saving json
t1.save('t1', file_type='json')

# Saving csv
t1.save('t1', file_type='csv')

t2 = Trajectory(x=x, y=y, t=[0.0, 0.5, 2.0])
t2.save('t2', file_type='json')
t2.save('t2', file_type='csv')


def test_load():
Expand All @@ -27,11 +30,15 @@ def test_load():
t1 = Trajectory.load('t1.json')
for tp, point in zip(t1, [[1,4], [2,5], [3,6]]):
assert (np.array(point) == tp.r).all()
t2 = Trajectory.load('t2.json')

# Loading csv
t1 = Trajectory.load('t1.csv')
for tp, point in zip(t1, [[1,4], [2,5], [3,6]]):
assert (np.array(point) == tp.r).all()
t2 = Trajectory.load('t2.csv')

os.remove('t1.json')
os.remove('t1.csv')
os.remove('t2.json')
os.remove('t2.csv')
2 changes: 1 addition & 1 deletion yupi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@
'Vector'
]

__version__ = '0.8.3'
__version__ = '0.8.4'
22 changes: 17 additions & 5 deletions yupi/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,25 @@ def __init__(self, x: np.ndarray = None, y: np.ndarray = None,
if lengths.count(lengths[0]) != len(lengths):
raise ValueError('All input arrays must have the same shape.')

self.__dt = dt
self.dt_mean = dt
self.__t0 = t0
self.__t = data[0]
self.ang = data[1]
self.traj_id= traj_id
self.lazy = lazy

if self.__t is None:
self.dt = dt if dt is not None else 1.0
self.dt_mean = dt if dt is not None else 1.0
self.dt_std = 0
self.__v: Vector = self.r.delta / self.dt
else:
self.dt = np.mean(np.array(self.__t.delta))
self.dt_mean = np.mean(np.array(self.__t.delta))
self.dt_std = np.std(np.array(self.__t.delta))
self.__v: Vector = (self.r.delta.T / self.__t.delta).T

if t is not None and dt is not None:
if abs(self.dt - dt) > _threshold:
if abs(self.dt_mean - dt) > _threshold:
raise ValueError("You are giving 'dt' and 't' but 'dt' "
"does not match with time values delta.")
if abs(self.dt_std - 0) > _threshold:
Expand All @@ -186,6 +188,16 @@ def __init__(self, x: np.ndarray = None, y: np.ndarray = None,

self.features = Features(self)

@property
def dt(self) -> float:
"""
Returns the time between each position data value.
If the time data is not uniformly spaced it returns an
estimated value.
"""
return self.dt_mean if self.__dt is None else self.__dt

@property
def uniformly_spaced(self) -> bool:
"""bool : True if the time data is uniformly spaced"""
Expand Down Expand Up @@ -530,7 +542,7 @@ def convert_to_list(vec: Vector):
ang = None if self.ang is None else self.ang.T
json_dict = {
'id': self.traj_id,
'dt': self.dt,
'dt': self.__dt,
'r': convert_to_list(self.r.T),
'ang': convert_to_list(ang),
't': convert_to_list(self.__t)
Expand All @@ -542,7 +554,7 @@ def _save_csv(self, path):
with open(str(path), 'w', newline='') as traj_file:
writer = csv.writer(traj_file, delimiter=',')
ang_shape = 0 if self.ang is None else self.ang.shape[1]
writer.writerow([self.traj_id, self.dt, self.dim, ang_shape])
writer.writerow([self.traj_id, self.__dt, self.dim, ang_shape])
for tp in self:
row = np.hstack(np.array([tp.r, tp.ang, tp.t], dtype=object))
writer.writerow(row)
Expand Down

0 comments on commit f1f9c78

Please sign in to comment.