diff --git a/src/qibo/parameter.py b/src/qibo/parameter.py index 5346b66100..e3e511967b 100644 --- a/src/qibo/parameter.py +++ b/src/qibo/parameter.py @@ -1,5 +1,3 @@ -import inspect - import numpy as np import sympy as sp @@ -12,20 +10,19 @@ class Parameter: final gate parameter. All possible analytical derivatives of the lambda function are calculated at the object initialisation using Sympy. - Example: - .. code-block:: python + Example:: - from qibo.parameter import Parameter - param = Parameter( - lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) + from qibo.parameter import Parameter + param = Parameter( + lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, + [1.5, 2.0, 3.0], + feature=[7.0], + ) - partial_derivative = param.get_partial_derivative(3) + partial_derivative = param.get_partial_derivative(3) - param.update_parameters(trainable=[15.0, 10.0, 7.0], feature=[5.0]) - gate_value = param() + param.update_parameters(trainable=[15.0, 10.0, 7.0], feature=[5.0]) + gate_value = param() Args: @@ -42,7 +39,6 @@ def __init__(self, func, trainable, feature=None): # lambda function self.lambdaf = func - self._check_inputs(func) self.derivatives = self._calculate_derivatives() @@ -58,46 +54,17 @@ def nparams(self): def nfeat(self): return len(self._feature) if isinstance(self._feature, list) else 0 - def _check_inputs(self, func): - """Verifies that the inputs are correct""" - parameters = inspect.signature(func).parameters - - if (self.nfeat + self.nparams) != len(parameters): - raise_error( - ValueError, - f"The lambda function has {len(parameters)} parameters, the input has {self.nfeat+self.nparams}.", - ) - - iterator = iter(parameters.items()) - - for i in range(self.nfeat): - x = next(iterator) - if x[0][0] != "x": - raise_error( - ValueError, - f"Parameter #{i} in the lambda function should be a feature starting with `x`", - ) - - for i in range(self.nparams): - x = next(iterator) - if x[0][:2] != "th": - raise_error( - ValueError, - f"Parameter #{self.nfeat+i} in the lambda function should be a trainable parameter starting with `th`", - ) - def _apply_func(self, function, fixed_params=None): """Applies lambda function and returns final gate parameter""" params = [] if self._feature is not None: - if isinstance(self._feature, list): - params.extend(self._feature) - else: - params.append(self._feature) + params.extend(self._feature) if fixed_params: params.extend(fixed_params) else: params.extend(self._trainable) + + # run function return float(function(*params)) def _calculate_derivatives(self): @@ -119,35 +86,19 @@ def _calculate_derivatives(self): def update_parameters(self, trainable=None, feature=None): """Update gate trainable parameter and feature values""" - if not isinstance(trainable, (list, np.ndarray)): - raise_error( - ValueError, "Trainable parameters must be given as list or numpy array" - ) - - if self.nparams != len(trainable): - raise_error( - ValueError, - f"{len(trainable)} trainable parameters given, need {self.nparams}", - ) - - if not isinstance(feature, (list, np.ndarray)) and self._feature != feature: - raise_error(ValueError, "Features must be given as list or numpy array") - - if self._feature is not None and self.nfeat != len(feature): - raise_error(ValueError, f"{len(feature)} features given, need {self.nfeat}") - if trainable is not None: self._trainable = trainable - if feature and self._feature: + + if feature is not None and self._feature is not None: self._feature = feature def get_indices(self, start_index): """Return list of respective indices of trainable parameters within a larger trainable parameter list""" - return [start_index + i for i in range(self.nparams)] + return (np.arange(self.nparams) + start_index).tolist() def get_fixed_part(self, trainable_idx): - """Retrieve parameter constant unaffected by a specific trainable parameter""" + """Retrieve constant term of lambda function with regard to a specific trainable parameter""" params = self._trainable.copy() params[trainable_idx] = 0.0 return self._apply_func(self.lambdaf, fixed_params=params) diff --git a/tests/test_parameter.py b/tests/test_parameter.py index d201bd475f..f83a1f294a 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -54,66 +54,37 @@ def test_parameter(): def test_parameter_errors(): - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda x, t1, th2, th3: x**2 * t1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) - - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, - [1.5, 2.0], - feature=[3.0, 7.0], - ) - - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda j, th1, th2, th3: j**2 * th1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) - - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) - param.update_parameters((1, 1, 1), [1]) - - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) - param.update_parameters([1, 1, 1], (1)) - - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) - param.update_parameters([1, 1], [1]) - - with pytest.raises(ValueError) as e_info: - param = Parameter( - lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, - [1.5, 2.0, 3.0], - feature=[7.0], - ) - param.update_parameters([1, 1, 1], [1, 1]) + param = Parameter( + lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, + [1.5, 2.0, 3.0], + feature=[7.0], + ) - with pytest.raises(ValueError) as e_info: + param.update_parameters([1, 1, 1], 1) + + try: + param() + assert False + except Exception as e: + assert True + + param.update_parameters([1, 1], [1]) + try: + param() + assert False + except Exception as e: + assert True + + param.update_parameters([1, 1, 1], [1, 1]) + try: + param() + assert False + except Exception as e: + assert True + + with pytest.raises(TypeError) as e_info: param = Parameter( lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2, [1.5, 2.0], feature=[7.0], ) - - -if __name__ == "__main__": - test_parameter()