Skip to content

Commit

Permalink
fix: accelerate linalg methods and cast types to doubles
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgepiloto committed Sep 13, 2024
1 parent 2804c63 commit 546160a
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 16 deletions.
67 changes: 67 additions & 0 deletions src/lamberthub/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Module containing various routines focused on linear algebra."""

from numba import njit as jit
import numpy as np


@jit
def dot(v1, v2):
"""Compute the dot product between two vectors.
Parameters
----------
v1 : ~np.array
First vector.
v2 : ~np.array
Second vector.
Returns
-------
float
Magnitude of the vector.
Notes
-----
This function casts the type of coordinates to double (float64) to avoid
this issue: https://github.com/numba/numba/issues/8676
"""
return v1.astype("d") @ v2.astype("d")


@jit
def cross(v1, v2):
"""Compute the cross product between two vectors.
Parameters
----------
v1 : ~np.array
First vector.
v2 : ~np.array
Second vector.
Returns
-------
~np.array
Resultant vector of the cross product between the two vectors.
"""
return np.cross(v1.astype("d"), v2.astype("d"))


@jit
def norm(vector):
"""Compute the magnitude of a vector.
Parameters
----------
vector : ~np.array
Vector whose magnitude is to be computed.
Returns
-------
float
Magnitude of the vector.
"""
return np.linalg.norm(vector.astype("d"))
7 changes: 2 additions & 5 deletions src/lamberthub/universal_solvers/izzo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
from numpy import cross, pi

from lamberthub.linalg import norm


@jit
def izzo2015(
Expand Down Expand Up @@ -374,8 +376,3 @@ def hyp2f1b(x):
if res_old == res:
return res
ii += 1


@jit
def norm(arr):
return np.sqrt(arr @ arr)
26 changes: 15 additions & 11 deletions src/lamberthub/utils/angles.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
"""Utilities related to angles computations"""

from numba import njit as jit
import numpy as np
from numpy import cross, dot
from numpy.linalg import norm

from lamberthub.linalg import cross, dot, norm


@jit
def get_transfer_angle(r1, r2, prograde):
"""
Solves for the transfer angle being known the sense of rotation.
"""Compute the transfer angle of the trajectory.
Initial and final position vectors are required together with the direction
of motion.
Parameters
----------
r1: np.array
r1 : ~np.array
Initial position vector.
r2: np.array
r2 : ~np.array
Final position vector.
prograde: bool
If True, it assumes prograde motion, otherwise assumes retrograde.
prograde : bool
``True`` for prograde motion, ``False`` otherwise.
Returns
-------
dtheta: float
dtheta : float
Transfer angle in radians.
"""
Expand All @@ -31,7 +35,7 @@ def get_transfer_angle(r1, r2, prograde):

# Solve for a unitary vector normal to the vector plane. Its direction and
# sense the one given by the cross product (right-hand) from r1 to r2.
h = cross(r1, r2) / norm(np.cross(r1, r2))
h = np.cross(r1, r2) / norm(np.cross(r1, r2))

# Compute the projection of the normal vector onto the reference plane.
alpha = dot(np.array([0, 0, 1]), h)
Expand Down Expand Up @@ -74,7 +78,7 @@ def get_orbit_normal_vector(r1, r2, prograde):

# Solve the projection onto the positive vertical direction of the
# fundamental plane.
alpha = dot(np.array([0, 0, 1]), i_h)
alpha = np.array([0, 0, 1]) @ i_h

# An prograde orbit always has a positive vertical component of its specific
# angular momentum. Therefore, we just need to check for this condition
Expand Down

0 comments on commit 546160a

Please sign in to comment.