From 02c3d74420e102bb40ed07b8c8aa412320da0c1d Mon Sep 17 00:00:00 2001 From: Yogiraj Gutte Date: Fri, 6 Dec 2024 01:58:07 +0530 Subject: [PATCH 1/5] MNT: move piecewise functions to separate file closes #667 --- rocketpy/mathutils/__init__.py | 2 +- rocketpy/mathutils/function.py | 103 ---------------------- rocketpy/mathutils/piecewise_function.py | 104 +++++++++++++++++++++++ rocketpy/motors/tank_geometry.py | 3 +- 4 files changed, 107 insertions(+), 105 deletions(-) create mode 100644 rocketpy/mathutils/piecewise_function.py diff --git a/rocketpy/mathutils/__init__.py b/rocketpy/mathutils/__init__.py index fad155583..3bac9227d 100644 --- a/rocketpy/mathutils/__init__.py +++ b/rocketpy/mathutils/__init__.py @@ -1,7 +1,7 @@ from .function import ( Function, - PiecewiseFunction, funcify_method, reset_funcified_methods, ) from .vector_matrix import Matrix, Vector +from .piecewise_function import PiecewiseFunction \ No newline at end of file diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index ca6005cf3..8ae7a2100 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -3419,109 +3419,6 @@ def __validate_extrapolation(self, extrapolation): return extrapolation -class PiecewiseFunction(Function): - """Class for creating piecewise functions. These kind of functions are - defined by a dictionary of functions, where the keys are tuples that - represent the domain of the function. The domains must be disjoint. - """ - - def __new__( - cls, - source, - inputs=None, - outputs=None, - interpolation="spline", - extrapolation=None, - datapoints=100, - ): - """ - Creates a piecewise function from a dictionary of functions. The keys of - the dictionary must be tuples that represent the domain of the function. - The domains must be disjoint. The piecewise function will be evaluated - at datapoints points to create Function object. - - Parameters - ---------- - source: dictionary - A dictionary of Function objects, where the keys are the domains. - inputs : list of strings - A list of strings that represent the inputs of the function. - outputs: list of strings - A list of strings that represent the outputs of the function. - interpolation: str - The type of interpolation to use. The default value is 'spline'. - extrapolation: str - The type of extrapolation to use. The default value is None. - datapoints: int - The number of points in which the piecewise function will be - evaluated to create a base function. The default value is 100. - """ - if inputs is None: - inputs = ["Scalar"] - if outputs is None: - outputs = ["Scalar"] - # Check if source is a dictionary - if not isinstance(source, dict): - raise TypeError("source must be a dictionary") - # Check if all keys are tuples - for key in source.keys(): - if not isinstance(key, tuple): - raise TypeError("keys of source must be tuples") - # Check if all domains are disjoint - for key1 in source.keys(): - for key2 in source.keys(): - if key1 != key2: - if key1[0] < key2[1] and key1[1] > key2[0]: - raise ValueError("domains must be disjoint") - - # Crate Function - def calc_output(func, inputs): - """Receives a list of inputs value and a function, populates another - list with the results corresponding to the same results. - - Parameters - ---------- - func : Function - The Function object to be - inputs : list, tuple, np.array - The array of points to applied the func to. - - Examples - -------- - >>> inputs = [0, 1, 2, 3, 4, 5] - >>> def func(x): - ... return x*10 - >>> calc_output(func, inputs) - [0, 10, 20, 30, 40, 50] - - Notes - ----- - In the future, consider using the built-in map function from python. - """ - output = np.zeros(len(inputs)) - for j, value in enumerate(inputs): - output[j] = func.get_value_opt(value) - return output - - input_data = [] - output_data = [] - for key in sorted(source.keys()): - i = np.linspace(key[0], key[1], datapoints) - i = i[~np.isin(i, input_data)] - input_data = np.concatenate((input_data, i)) - - f = Function(source[key]) - output_data = np.concatenate((output_data, calc_output(f, i))) - - return Function( - np.concatenate(([input_data], [output_data])).T, - inputs=inputs, - outputs=outputs, - interpolation=interpolation, - extrapolation=extrapolation, - ) - - def funcify_method(*args, **kwargs): # pylint: disable=too-many-statements """Decorator factory to wrap methods as Function objects and save them as cached properties. diff --git a/rocketpy/mathutils/piecewise_function.py b/rocketpy/mathutils/piecewise_function.py new file mode 100644 index 000000000..b64b27ad1 --- /dev/null +++ b/rocketpy/mathutils/piecewise_function.py @@ -0,0 +1,104 @@ +import numpy as np +from rocketpy.mathutils.function import Function + +class PiecewiseFunction(Function): + """Class for creating piecewise functions. These kind of functions are + defined by a dictionary of functions, where the keys are tuples that + represent the domain of the function. The domains must be disjoint. + """ + + def __new__( + cls, + source, + inputs=None, + outputs=None, + interpolation="spline", + extrapolation=None, + datapoints=100, + ): + """ + Creates a piecewise function from a dictionary of functions. The keys of + the dictionary must be tuples that represent the domain of the function. + The domains must be disjoint. The piecewise function will be evaluated + at datapoints points to create Function object. + + Parameters + ---------- + source: dictionary + A dictionary of Function objects, where the keys are the domains. + inputs : list of strings + A list of strings that represent the inputs of the function. + outputs: list of strings + A list of strings that represent the outputs of the function. + interpolation: str + The type of interpolation to use. The default value is 'spline'. + extrapolation: str + The type of extrapolation to use. The default value is None. + datapoints: int + The number of points in which the piecewise function will be + evaluated to create a base function. The default value is 100. + """ + if inputs is None: + inputs = ["Scalar"] + if outputs is None: + outputs = ["Scalar"] + # Check if source is a dictionary + if not isinstance(source, dict): + raise TypeError("source must be a dictionary") + # Check if all keys are tuples + for key in source.keys(): + if not isinstance(key, tuple): + raise TypeError("keys of source must be tuples") + # Check if all domains are disjoint + for key1 in source.keys(): + for key2 in source.keys(): + if key1 != key2: + if key1[0] < key2[1] and key1[1] > key2[0]: + raise ValueError("domains must be disjoint") + + # Crate Function + def calc_output(func, inputs): + """Receives a list of inputs value and a function, populates another + list with the results corresponding to the same results. + + Parameters + ---------- + func : Function + The Function object to be + inputs : list, tuple, np.array + The array of points to applied the func to. + + Examples + -------- + >>> inputs = [0, 1, 2, 3, 4, 5] + >>> def func(x): + ... return x*10 + >>> calc_output(func, inputs) + [0, 10, 20, 30, 40, 50] + + Notes + ----- + In the future, consider using the built-in map function from python. + """ + output = np.zeros(len(inputs)) + for j, value in enumerate(inputs): + output[j] = func.get_value_opt(value) + return output + + input_data = [] + output_data = [] + for key in sorted(source.keys()): + i = np.linspace(key[0], key[1], datapoints) + i = i[~np.isin(i, input_data)] + input_data = np.concatenate((input_data, i)) + + f = Function(source[key]) + output_data = np.concatenate((output_data, calc_output(f, i))) + + return Function( + np.concatenate(([input_data], [output_data])).T, + inputs=inputs, + outputs=outputs, + interpolation=interpolation, + extrapolation=extrapolation, + ) \ No newline at end of file diff --git a/rocketpy/motors/tank_geometry.py b/rocketpy/motors/tank_geometry.py index 272f8fc93..4fd5910c3 100644 --- a/rocketpy/motors/tank_geometry.py +++ b/rocketpy/motors/tank_geometry.py @@ -2,7 +2,8 @@ import numpy as np -from ..mathutils.function import Function, PiecewiseFunction, funcify_method +from ..mathutils.function import Function, funcify_method +from ..mathutils.piecewise_function import PiecewiseFunction from ..plots.tank_geometry_plots import _TankGeometryPlots from ..prints.tank_geometry_prints import _TankGeometryPrints From 920f15c36bd9f39b9114ed4595420a080c920e3a Mon Sep 17 00:00:00 2001 From: Yogiraj Gutte Date: Fri, 6 Dec 2024 02:12:12 +0530 Subject: [PATCH 2/5] improved import for linting --- rocketpy/mathutils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rocketpy/mathutils/__init__.py b/rocketpy/mathutils/__init__.py index 3bac9227d..5daf1ff80 100644 --- a/rocketpy/mathutils/__init__.py +++ b/rocketpy/mathutils/__init__.py @@ -3,5 +3,5 @@ funcify_method, reset_funcified_methods, ) -from .vector_matrix import Matrix, Vector -from .piecewise_function import PiecewiseFunction \ No newline at end of file +from .piecewise_function import PiecewiseFunction +from .vector_matrix import Matrix, Vector \ No newline at end of file From 5033694129a5a5e9983ce742074f7c706a7bf2e8 Mon Sep 17 00:00:00 2001 From: Lucas de Oliveira Prates Date: Sat, 14 Dec 2024 19:24:52 -0300 Subject: [PATCH 3/5] MNT: applying code formaters --- rocketpy/mathutils/__init__.py | 8 ++------ rocketpy/mathutils/piecewise_function.py | 4 +++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/rocketpy/mathutils/__init__.py b/rocketpy/mathutils/__init__.py index 5daf1ff80..181b40e55 100644 --- a/rocketpy/mathutils/__init__.py +++ b/rocketpy/mathutils/__init__.py @@ -1,7 +1,3 @@ -from .function import ( - Function, - funcify_method, - reset_funcified_methods, -) +from .function import Function, funcify_method, reset_funcified_methods from .piecewise_function import PiecewiseFunction -from .vector_matrix import Matrix, Vector \ No newline at end of file +from .vector_matrix import Matrix, Vector diff --git a/rocketpy/mathutils/piecewise_function.py b/rocketpy/mathutils/piecewise_function.py index b64b27ad1..79ebd1b71 100644 --- a/rocketpy/mathutils/piecewise_function.py +++ b/rocketpy/mathutils/piecewise_function.py @@ -1,6 +1,8 @@ import numpy as np + from rocketpy.mathutils.function import Function + class PiecewiseFunction(Function): """Class for creating piecewise functions. These kind of functions are defined by a dictionary of functions, where the keys are tuples that @@ -101,4 +103,4 @@ def calc_output(func, inputs): outputs=outputs, interpolation=interpolation, extrapolation=extrapolation, - ) \ No newline at end of file + ) From 664825fd423e1adec572ef5f57ba1691a7906dd4 Mon Sep 17 00:00:00 2001 From: Lucas de Oliveira Prates Date: Sun, 15 Dec 2024 13:18:42 -0300 Subject: [PATCH 4/5] ENH: simplifying and optimizing the function, implementing tests. --- rocketpy/mathutils/piecewise_function.py | 83 ++++++++++-------------- tests/unit/test_piecewise_function.py | 35 ++++++++++ 2 files changed, 70 insertions(+), 48 deletions(-) create mode 100644 tests/unit/test_piecewise_function.py diff --git a/rocketpy/mathutils/piecewise_function.py b/rocketpy/mathutils/piecewise_function.py index 79ebd1b71..fce26227e 100644 --- a/rocketpy/mathutils/piecewise_function.py +++ b/rocketpy/mathutils/piecewise_function.py @@ -40,62 +40,25 @@ def __new__( The number of points in which the piecewise function will be evaluated to create a base function. The default value is 100. """ + cls.__validate__source(source) if inputs is None: inputs = ["Scalar"] if outputs is None: outputs = ["Scalar"] - # Check if source is a dictionary - if not isinstance(source, dict): - raise TypeError("source must be a dictionary") - # Check if all keys are tuples - for key in source.keys(): - if not isinstance(key, tuple): - raise TypeError("keys of source must be tuples") - # Check if all domains are disjoint - for key1 in source.keys(): - for key2 in source.keys(): - if key1 != key2: - if key1[0] < key2[1] and key1[1] > key2[0]: - raise ValueError("domains must be disjoint") - - # Crate Function - def calc_output(func, inputs): - """Receives a list of inputs value and a function, populates another - list with the results corresponding to the same results. - - Parameters - ---------- - func : Function - The Function object to be - inputs : list, tuple, np.array - The array of points to applied the func to. - - Examples - -------- - >>> inputs = [0, 1, 2, 3, 4, 5] - >>> def func(x): - ... return x*10 - >>> calc_output(func, inputs) - [0, 10, 20, 30, 40, 50] - - Notes - ----- - In the future, consider using the built-in map function from python. - """ - output = np.zeros(len(inputs)) - for j, value in enumerate(inputs): - output[j] = func.get_value_opt(value) - return output input_data = [] output_data = [] - for key in sorted(source.keys()): - i = np.linspace(key[0], key[1], datapoints) - i = i[~np.isin(i, input_data)] - input_data = np.concatenate((input_data, i)) + for interval in sorted(source.keys()): + grid = np.linspace(interval[0], interval[1], datapoints) + + # since intervals are disjoint and sorted, we only need to check + # if the first point is already included + if interval[0] in input_data: + grid = np.delete(grid, 0) + input_data = np.concatenate((input_data, grid)) - f = Function(source[key]) - output_data = np.concatenate((output_data, calc_output(f, i))) + f = Function(source[interval]) + output_data = np.concatenate((output_data, f(grid))) return Function( np.concatenate(([input_data], [output_data])).T, @@ -104,3 +67,27 @@ def calc_output(func, inputs): interpolation=interpolation, extrapolation=extrapolation, ) + + @staticmethod + def __validate__source(source): + """Validates that source is dictionary with non-overlapping + intervals + + Parameters + ---------- + source : dict + A dictionary of Function objects, where the keys are the domains. + """ + # Check if source is a dictionary + if not isinstance(source, dict): + raise TypeError("source must be a dictionary") + # Check if all keys are tuples + for key in source.keys(): + if not isinstance(key, tuple): + raise TypeError("keys of source must be tuples") + # Check if all domains are disjoint + for interval1 in source.keys(): + for interval2 in source.keys(): + if interval1 != interval2: + if interval1[0] < interval2[1] and interval1[1] > interval2[0]: + raise ValueError("domains must be disjoint") diff --git a/tests/unit/test_piecewise_function.py b/tests/unit/test_piecewise_function.py new file mode 100644 index 000000000..347f3de27 --- /dev/null +++ b/tests/unit/test_piecewise_function.py @@ -0,0 +1,35 @@ +import pytest + +from rocketpy import PiecewiseFunction + + +@pytest.mark.parametrize( + "source", + [ + ((0, 4), lambda x: x), + {"0-4": lambda x: x}, + {(0, 4): lambda x: x, (3, 5): lambda x: 2 * x}, + ], +) +def test_invalid_source(source): + """Test an error is raised when the source parameter is invalid""" + with pytest.raises((TypeError, ValueError)): + PiecewiseFunction(source) + + +@pytest.mark.parametrize( + "source", + [ + {(-1, 0): lambda x: -x, (0, 1): lambda x: x}, + { + (0, 1): lambda x: x, + (1, 2): lambda x: 1, + (2, 3): lambda x: 3 - x, + }, + ], +) +@pytest.mark.parametrize("inputs", [None, "X"]) +@pytest.mark.parametrize("outputs", [None, "Y"]) +def test_new(source, inputs, outputs): + """Test if PiecewiseFunction.__new__ runs correctly""" + PiecewiseFunction(source, inputs, outputs) From a9977746ff0850ee94fa859c8aa2f042a4406261 Mon Sep 17 00:00:00 2001 From: Lucas de Oliveira Prates Date: Sun, 15 Dec 2024 20:56:22 -0300 Subject: [PATCH 5/5] MNT: update changelog and apply changes suggested in review --- CHANGELOG.md | 2 +- rocketpy/mathutils/piecewise_function.py | 25 ++++++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1766e2dcf..4658a75a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,7 +41,7 @@ Attention: The newest changes should be on top --> ### Changed -- +- MNT: move piecewise functions to separate file [#746](https://github.com/RocketPy-Team/RocketPy/pull/746) ### Fixed diff --git a/rocketpy/mathutils/piecewise_function.py b/rocketpy/mathutils/piecewise_function.py index fce26227e..086e6d1da 100644 --- a/rocketpy/mathutils/piecewise_function.py +++ b/rocketpy/mathutils/piecewise_function.py @@ -46,19 +46,20 @@ def __new__( if outputs is None: outputs = ["Scalar"] - input_data = [] - output_data = [] - for interval in sorted(source.keys()): - grid = np.linspace(interval[0], interval[1], datapoints) + input_data = np.array([]) + output_data = np.array([]) + for lower, upper in sorted(source.keys()): + grid = np.linspace(lower, upper, datapoints) # since intervals are disjoint and sorted, we only need to check # if the first point is already included - if interval[0] in input_data: - grid = np.delete(grid, 0) + if input_data.size != 0: + if lower == input_data[-1]: + grid = np.delete(grid, 0) input_data = np.concatenate((input_data, grid)) - f = Function(source[interval]) - output_data = np.concatenate((output_data, f(grid))) + f = Function(source[(lower, upper)]) + output_data = np.concatenate((output_data, f.get_value(grid))) return Function( np.concatenate(([input_data], [output_data])).T, @@ -86,8 +87,8 @@ def __validate__source(source): if not isinstance(key, tuple): raise TypeError("keys of source must be tuples") # Check if all domains are disjoint - for interval1 in source.keys(): - for interval2 in source.keys(): - if interval1 != interval2: - if interval1[0] < interval2[1] and interval1[1] > interval2[0]: + for lower1, upper1 in source.keys(): + for lower2, upper2 in source.keys(): + if (lower1, upper1) != (lower2, upper2): + if lower1 < upper2 and upper1 > lower2: raise ValueError("domains must be disjoint")