Skip to content

Commit

Permalink
ale comments implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
acse-b99192e1 committed Sep 11, 2023
1 parent 96d1b4b commit 602d8c8
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 123 deletions.
83 changes: 17 additions & 66 deletions src/qibo/parameter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import inspect

import numpy as np
import sympy as sp

Expand All @@ -12,20 +10,19 @@ class Parameter:
final gate parameter. All possible analytical derivatives of the lambda function are
calculated at the object initialisation using Sympy.
Example:
.. code-block:: python
Example::
from qibo.parameter import Parameter
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)
from qibo.parameter import Parameter
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)
partial_derivative = param.get_partial_derivative(3)
partial_derivative = param.get_partial_derivative(3)
param.update_parameters(trainable=[15.0, 10.0, 7.0], feature=[5.0])
gate_value = param()
param.update_parameters(trainable=[15.0, 10.0, 7.0], feature=[5.0])
gate_value = param()
Args:
Expand All @@ -42,7 +39,6 @@ def __init__(self, func, trainable, feature=None):

# lambda function
self.lambdaf = func
self._check_inputs(func)

self.derivatives = self._calculate_derivatives()

Expand All @@ -58,46 +54,17 @@ def nparams(self):
def nfeat(self):
return len(self._feature) if isinstance(self._feature, list) else 0

def _check_inputs(self, func):
"""Verifies that the inputs are correct"""
parameters = inspect.signature(func).parameters

if (self.nfeat + self.nparams) != len(parameters):
raise_error(
ValueError,
f"The lambda function has {len(parameters)} parameters, the input has {self.nfeat+self.nparams}.",
)

iterator = iter(parameters.items())

for i in range(self.nfeat):
x = next(iterator)
if x[0][0] != "x":
raise_error(
ValueError,
f"Parameter #{i} in the lambda function should be a feature starting with `x`",
)

for i in range(self.nparams):
x = next(iterator)
if x[0][:2] != "th":
raise_error(
ValueError,
f"Parameter #{self.nfeat+i} in the lambda function should be a trainable parameter starting with `th`",
)

def _apply_func(self, function, fixed_params=None):
"""Applies lambda function and returns final gate parameter"""
params = []
if self._feature is not None:
if isinstance(self._feature, list):
params.extend(self._feature)
else:
params.append(self._feature)
params.extend(self._feature)
if fixed_params:
params.extend(fixed_params)
else:
params.extend(self._trainable)

# run function
return float(function(*params))

def _calculate_derivatives(self):
Expand All @@ -119,35 +86,19 @@ def _calculate_derivatives(self):

def update_parameters(self, trainable=None, feature=None):
"""Update gate trainable parameter and feature values"""
if not isinstance(trainable, (list, np.ndarray)):
raise_error(
ValueError, "Trainable parameters must be given as list or numpy array"
)

if self.nparams != len(trainable):
raise_error(
ValueError,
f"{len(trainable)} trainable parameters given, need {self.nparams}",
)

if not isinstance(feature, (list, np.ndarray)) and self._feature != feature:
raise_error(ValueError, "Features must be given as list or numpy array")

if self._feature is not None and self.nfeat != len(feature):
raise_error(ValueError, f"{len(feature)} features given, need {self.nfeat}")

if trainable is not None:
self._trainable = trainable
if feature and self._feature:

if feature is not None and self._feature is not None:
self._feature = feature

def get_indices(self, start_index):
"""Return list of respective indices of trainable parameters within
a larger trainable parameter list"""
return [start_index + i for i in range(self.nparams)]
return (np.arange(self.nparams) + start_index).tolist()

def get_fixed_part(self, trainable_idx):
"""Retrieve parameter constant unaffected by a specific trainable parameter"""
"""Retrieve constant term of lambda function with regard to a specific trainable parameter"""
params = self._trainable.copy()
params[trainable_idx] = 0.0
return self._apply_func(self.lambdaf, fixed_params=params)
Expand Down
85 changes: 28 additions & 57 deletions tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,66 +54,37 @@ def test_parameter():


def test_parameter_errors():
with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda x, t1, th2, th3: x**2 * t1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)

with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0],
feature=[3.0, 7.0],
)

with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda j, th1, th2, th3: j**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)

with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)
param.update_parameters((1, 1, 1), [1])

with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)
param.update_parameters([1, 1, 1], (1))

with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)
param.update_parameters([1, 1], [1])

with pytest.raises(ValueError) as e_info:
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)
param.update_parameters([1, 1, 1], [1, 1])
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0, 3.0],
feature=[7.0],
)

with pytest.raises(ValueError) as e_info:
param.update_parameters([1, 1, 1], 1)

try:
param()
assert False
except Exception as e:
assert True

param.update_parameters([1, 1], [1])
try:
param()
assert False
except Exception as e:
assert True

param.update_parameters([1, 1, 1], [1, 1])
try:
param()
assert False
except Exception as e:
assert True

with pytest.raises(TypeError) as e_info:
param = Parameter(
lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
[1.5, 2.0],
feature=[7.0],
)


if __name__ == "__main__":
test_parameter()

0 comments on commit 602d8c8

Please sign in to comment.