Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Allow for Alternative and Custom ODE Solvers. #748

Merged
merged 3 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@
"pytest",
"pytz",
"quantile",
"Radau",
"Rdot",
"referece",
"relativetoground",
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Attention: The newest changes should be on top -->

### Added

-
- ENH: Allow for Alternative and Custom ODE Solvers. [#748](https://github.com/RocketPy-Team/RocketPy/pull/748)

### Changed

Expand Down
74 changes: 63 additions & 11 deletions rocketpy/simulation/flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np
import simplekml
from scipy import integrate
from scipy.integrate import BDF, DOP853, LSODA, RK23, RK45, OdeSolver, Radau

from ..mathutils.function import Function, funcify_method
from ..mathutils.vector_matrix import Matrix, Vector
Expand All @@ -24,8 +24,19 @@
quaternions_to_spin,
)

ODE_SOLVER_MAP = {
'RK23': RK23,
'RK45': RK45,
'DOP853': DOP853,
'Radau': Radau,
'BDF': BDF,
'LSODA': LSODA,
}
phmbressan marked this conversation as resolved.
Show resolved Hide resolved

class Flight: # pylint: disable=too-many-public-methods

# pylint: disable=too-many-public-methods
# pylint: disable=too-many-instance-attributes
class Flight:
"""Keeps all flight information and has a method to simulate flight.

Attributes
Expand Down Expand Up @@ -506,6 +517,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
verbose=False,
name="Flight",
equations_of_motion="standard",
ode_solver="LSODA",
):
"""Run a trajectory simulation.

Expand Down Expand Up @@ -581,10 +593,23 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
more restricted set of equations of motion that only works for
solid propulsion rockets. Such equations were used in RocketPy v0
and are kept here for backwards compatibility.
ode_solver : str, ``scipy.integrate.OdeSolver``, optional
Integration method to use to solve the equations of motion ODE.
Available options are: 'RK23', 'RK45', 'DOP853', 'Radau', 'BDF',
'LSODA' from ``scipy.integrate.solve_ivp``.
Default is 'LSODA', which is recommended for most flights.
A custom ``scipy.integrate.OdeSolver`` can be passed as well.
For more information on the integration methods, see the scipy
documentation [1]_.


Returns
-------
None

References
----------
.. [1] https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.solve_ivp.html
phmbressan marked this conversation as resolved.
Show resolved Hide resolved
"""
# Save arguments
self.env = environment
Expand All @@ -605,6 +630,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-statements
self.terminate_on_apogee = terminate_on_apogee
self.name = name
self.equations_of_motion = equations_of_motion
self.ode_solver = ode_solver

# Controller initialization
self.__init_controllers()
Expand Down Expand Up @@ -651,15 +677,16 @@ def __simulate(self, verbose):

# Create solver for this flight phase # TODO: allow different integrators
self.function_evaluations.append(0)
phase.solver = integrate.LSODA(

phase.solver = self._solver(
phase.derivative,
t0=phase.t,
y0=self.y_sol,
t_bound=phase.time_bound,
min_step=self.min_time_step,
max_step=self.max_time_step,
rtol=self.rtol,
atol=self.atol,
max_step=self.max_time_step,
min_step=self.min_time_step,
)

# Initialize phase time nodes
Expand Down Expand Up @@ -691,13 +718,14 @@ def __simulate(self, verbose):
for node_index, node in self.time_iterator(phase.time_nodes):
# Determine time bound for this time node
node.time_bound = phase.time_nodes[node_index + 1].t
# NOTE: Setting the time bound and status for the phase solver,
# and updating its internal state for the next integration step.
phase.solver.t_bound = node.time_bound
phase.solver._lsoda_solver._integrator.rwork[0] = phase.solver.t_bound
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
if self.__is_lsoda:
phase.solver._lsoda_solver._integrator.rwork[0] = (
phase.solver.t_bound
)
phase.solver._lsoda_solver._integrator.call_args[4] = (
phase.solver._lsoda_solver._integrator.rwork
)
phase.solver.status = "running"

# Feed required parachute and discrete controller triggers
Expand Down Expand Up @@ -1185,6 +1213,8 @@ def __init_solver_monitors(self):
self.t = self.solution[-1][0]
self.y_sol = self.solution[-1][1:]

self.__set_ode_solver(self.ode_solver)

def __init_equations_of_motion(self):
"""Initialize equations of motion."""
if self.equations_of_motion == "solid_propulsion":
Expand Down Expand Up @@ -1222,6 +1252,28 @@ def __cache_sensor_data(self):
sensor_data[sensor] = sensor.measured_data[:]
self.sensor_data = sensor_data

def __set_ode_solver(self, solver):
"""Sets the ODE solver to be used in the simulation.

Parameters
----------
solver : str, ``scipy.integrate.OdeSolver``
Integration method to use to solve the equations of motion ODE,
or a custom ``scipy.integrate.OdeSolver``.
"""
if isinstance(solver, OdeSolver):
self._solver = solver
else:
try:
self._solver = ODE_SOLVER_MAP[solver]
except KeyError as e:
raise ValueError(
f"Invalid ``ode_solver`` input: {solver}. "
f"Available options are: {', '.join(ODE_SOLVER_MAP.keys())}"
) from e

self.__is_lsoda = hasattr(self._solver, "_lsoda_solver")

@cached_property
def effective_1rl(self):
"""Original rail length minus the distance measured from nozzle exit
Expand Down
39 changes: 38 additions & 1 deletion tests/integration/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


@patch("matplotlib.pyplot.show")
def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-argument
# pylint: disable=unused-argument
def test_all_info(mock_show, flight_calisto_robust):
"""Test that the flight class is working as intended. This basically calls
the all_info() method and checks if it returns None. It is not testing if
the values are correct, but whether the method is working without errors.
Expand All @@ -27,6 +28,42 @@ def test_all_info(mock_show, flight_calisto_robust): # pylint: disable=unused-a
assert flight_calisto_robust.all_info() is None


@pytest.mark.slow
@patch("matplotlib.pyplot.show")
@pytest.mark.parametrize("solver_method", ["RK45", "DOP853", "Radau", "BDF"])
# RK23 is unstable and requires a very low tolerance to work
# pylint: disable=unused-argument
def test_all_info_different_solvers(
mock_show, calisto_robust, example_spaceport_env, solver_method
):
"""Test that the flight class is working as intended with different solver
methods. This basically calls the all_info() method and checks if it returns
None. It is not testing if the values are correct, but whether the method is
working without errors.

Parameters
----------
mock_show : unittest.mock.MagicMock
Mock object to replace matplotlib.pyplot.show
calisto_robust : rocketpy.Rocket
Rocket to be simulated. See the conftest.py file for more info.
example_spaceport_env : rocketpy.Environment
Environment to be simulated. See the conftest.py file for more info.
solver_method : str
The solver method to be used in the simulation.
"""
test_flight = Flight(
environment=example_spaceport_env,
rocket=calisto_robust,
rail_length=5.2,
inclination=85,
heading=0,
terminate_on_apogee=False,
ode_solver=solver_method,
)
assert test_flight.all_info() is None


class TestExportData:
"""Tests the export_data method of the Flight class."""

Expand Down
Loading