Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
jmorgadov committed Nov 13, 2023
2 parents f2b3c6d + ac91a27 commit 27f2cfb
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
15 changes: 10 additions & 5 deletions tests/test_trajectory/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,19 @@ def test_creation_general():
def test_diff_methods():
x = [1, 2, 4, 8, 16]

Trajectory.global_diff_method(DiffMethod.LINEAR_DIFF, WindowType.FORWARD)
traj = Trajectory(x=x)

assert traj.v == pytest.approx([1, 2, 4, 8, 8], rel=APPROX_REL_TOLERANCE)
assert traj.a == pytest.approx([1, 2, 4, 0, 0], rel=APPROX_REL_TOLERANCE)

Trajectory.global_diff_method(DiffMethod.LINEAR_DIFF, WindowType.CENTRAL)
traj = Trajectory(x=x)

assert traj.v == pytest.approx([3 / 2, 3 / 2, 3, 6, 6], rel=APPROX_REL_TOLERANCE)
assert traj.a == pytest.approx(
[3 / 4, 3 / 4, 9 / 4, 3 / 2, 3 / 2], rel=APPROX_REL_TOLERANCE
)

Trajectory.global_diff_method(DiffMethod.LINEAR_DIFF)
traj.set_diff_method(DiffMethod.LINEAR_DIFF, WindowType.BACKWARD)

Expand All @@ -100,10 +107,8 @@ def test_diff_methods():

traj = Trajectory(x=x)

assert traj.v == pytest.approx([3 / 2, 3 / 2, 3, 6, 6], rel=APPROX_REL_TOLERANCE)
assert traj.a == pytest.approx(
[3 / 4, 3 / 4, 9 / 4, 3 / 2, 3 / 2], rel=APPROX_REL_TOLERANCE
)
assert traj.v == pytest.approx([1, 2, 4, 8, 8], rel=APPROX_REL_TOLERANCE)
assert traj.a == pytest.approx([1, 2, 4, 0, 0], rel=APPROX_REL_TOLERANCE)

vel_est = {
"method": DiffMethod.FORNBERG_DIFF,
Expand Down
2 changes: 1 addition & 1 deletion yupi/core/serializers/csv_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def save(
"method", diff.DiffMethod.LINEAR_DIFF
)
diff_win = Trajectory.general_diff_est.get(
"window_type", diff.WindowType.CENTRAL
"window_type", diff.WindowType.FORWARD
)
accuracy = Trajectory.general_diff_est.get("accuracy", 1)
method = traj.diff_est.get("method", diff_method).value
Expand Down
2 changes: 1 addition & 1 deletion yupi/core/serializers/json_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def to_json(traj: Trajectory) -> dict:
"""

method = Trajectory.general_diff_est.get("method", diff.DiffMethod.LINEAR_DIFF)
window = Trajectory.general_diff_est.get("window_type", diff.WindowType.CENTRAL)
window = Trajectory.general_diff_est.get("window_type", diff.WindowType.FORWARD)
accuracy = Trajectory.general_diff_est.get("accuracy", 1)
diff_est = {
"method": traj.diff_est.get("method", method).value,
Expand Down
14 changes: 7 additions & 7 deletions yupi/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class Trajectory:

general_diff_est: Dict[str, Any] = {
"method": diff.DiffMethod.LINEAR_DIFF,
"window_type": diff.WindowType.CENTRAL,
"window_type": diff.WindowType.FORWARD,
}

def __init__(
Expand Down Expand Up @@ -261,7 +261,7 @@ def __init__(
def set_diff_method(
self,
method: diff.DiffMethod,
window_type: diff.WindowType = diff.WindowType.CENTRAL,
window_type: diff.WindowType = diff.WindowType.FORWARD,
accuracy: int = 1,
):
"""
Expand All @@ -288,7 +288,7 @@ def set_diff_method(
def set_vel_method(
self,
method: diff.DiffMethod,
window_type: diff.WindowType = diff.WindowType.CENTRAL,
window_type: diff.WindowType = diff.WindowType.FORWARD,
accuracy: int = 1,
):
"""
Expand All @@ -306,7 +306,7 @@ def set_vel_method(
@staticmethod
def global_diff_method(
method: diff.DiffMethod,
window_type: diff.WindowType = diff.WindowType.CENTRAL,
window_type: diff.WindowType = diff.WindowType.FORWARD,
accuracy: int = 1,
):
"""
Expand All @@ -332,7 +332,7 @@ def global_diff_method(
@staticmethod
def global_vel_method(
method: diff.DiffMethod,
window_type: diff.WindowType = diff.WindowType.CENTRAL,
window_type: diff.WindowType = diff.WindowType.FORWARD,
accuracy: int = 1,
):
"""
Expand Down Expand Up @@ -774,7 +774,7 @@ def convert_to_list(vec: Optional[Vector]):

diff_est = {
"method": self.diff_est.get("method", diff.DiffMethod.LINEAR_DIFF).value,
"window_type": self.diff_est.get("window", diff.WindowType.CENTRAL).value,
"window_type": self.diff_est.get("window", diff.WindowType.FORWARD).value,
"accuracy": self.diff_est.get("accuracy", 1),
}

Expand All @@ -794,7 +794,7 @@ def _save_csv(self, path: Union[str, Path]) -> None:
writer.writerow([self.traj_id, self.__dt, self.dim])

default_diff_method = diff.DiffMethod.LINEAR_DIFF
default_diff_window = diff.WindowType.CENTRAL
default_diff_window = diff.WindowType.FORWARD
default_diff_accuracy = 1
method = self.diff_est.get("method", default_diff_method).value
window = self.diff_est.get("window", default_diff_window).value
Expand Down

0 comments on commit 27f2cfb

Please sign in to comment.