Commit

Implemented Ale and Matteo's comments
acse-b99192e1 committed Sep 13, 2023
1 parent 602d8c8 commit 3bfba6e
Showing 2 changed files with 65 additions and 47 deletions.
57 changes: 35 additions & 22 deletions src/qibo/parameter.py
@@ -5,7 +5,7 @@


 class Parameter:
-    """Object which allows for variational gate parameters. Several trainable parameter
+    """Object which allows for variational gate parameters. Several trainable parameters
     and possibly features are linked through a lambda function which returns the
     final gate parameter. All possible analytical derivatives of the lambda function are
     calculated at the object initialisation using Sympy.

@@ -33,9 +33,10 @@ class Parameter:
         feature (list or np.ndarray): array containing possible input features x
     """

-    def __init__(self, func, trainable, feature=None):
+    def __init__(self, func, features=None, trainable=None, nofeatures=False):
         self._trainable = trainable
-        self._feature = feature
+        self._features = features
+        self._nofeatures = nofeatures

         # lambda function
         self.lambdaf = func

@@ -48,18 +49,23 @@ def __call__(self):

     @property
     def nparams(self):
-        return len(self._trainable)
+        """Returns the number of trainable parameters"""
+        try:
+            return len(self._trainable)
+        except TypeError:
+            return 0

     @property
     def nfeat(self):
-        return len(self._feature) if isinstance(self._feature, list) else 0
+        """Returns the number of features"""
+        return len(self._features) if isinstance(self._features, list) else 0

     def _apply_func(self, function, fixed_params=None):
         """Applies lambda function and returns final gate parameter"""
         params = []
-        if self._feature is not None:
-            params.extend(self._feature)
-        if fixed_params:
+        if self._features is not None:
+            params.extend(self._features)
+        if fixed_params is not None:
             params.extend(fixed_params)
         else:
             params.extend(self._trainable)

@@ -70,10 +76,8 @@ def _calculate_derivatives(self):

     def _calculate_derivatives(self):
         """Calculates derivatives w.r.t to all trainable parameters"""
         vars = []
-        for i in range(self.nfeat):
-            vars.append(sp.Symbol(f"x{i}"))
-        for i in range(self.nparams):
-            vars.append(sp.Symbol(f"th{i}"))
+        for i in range(self.lambdaf.__code__.co_argcount):
+            vars.append(sp.Symbol(f"p{i}"))

         expr = sp.sympify(self.lambdaf(*vars))

@@ -84,26 +88,35 @@

         return derivatives

-    def update_parameters(self, trainable=None, feature=None):
-        """Update gate trainable parameter and feature values"""
-        if trainable is not None:
-            self._trainable = trainable
+    def gettrainable(self):
+        return self._trainable

-        if feature is not None and self._feature is not None:
-            self._feature = feature
+    def settrainable(self, value):
+        self._trainable = value

-    def get_indices(self, start_index):
+    def getfeatures(self):
+        return self._features
+
+    def setfeatures(self, value):
+        self._features = value if not self._nofeatures else None
+
+    trainable = property(
+        gettrainable, settrainable, doc="I'm the trainable parameters property."
+    )
+    features = property(getfeatures, setfeatures, doc="I'm the features property.")
+
+    def trainable_parameter_indices(self, start_index):
         """Return list of respective indices of trainable parameters within
-        a larger trainable parameter list"""
+        the larger trainable parameter list of a circuit for example"""
         return (np.arange(self.nparams) + start_index).tolist()

-    def get_fixed_part(self, trainable_idx):
+    def unaffected_by(self, trainable_idx):
         """Retrieve constant term of lambda function with regard to a specific trainable parameter"""
         params = self._trainable.copy()
         params[trainable_idx] = 0.0
         return self._apply_func(self.lambdaf, fixed_params=params)

-    def get_partial_derivative(self, trainable_idx):
+    def partial_derivative(self, trainable_idx):
         """Get derivative w.r.t a trainable parameter"""
         deriv = self.derivatives[trainable_idx]
         return self._apply_func(deriv)
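
For orientation, here is a minimal usage sketch of the reworked interface, based only on the signature and method names introduced in this diff (the numeric values are illustrative, the import path is taken from the diff header, and the float return types assume the default backend):

    import numpy as np
    from qibo.parameter import Parameter

    # one feature x and two trainables th1, th2: f = x * th1 + th2**2
    param = Parameter(
        lambda x, th1, th2: x * th1 + th2**2,
        features=[2.0],
        trainable=[3.0, 4.0],
    )

    param()                               # 2*3 + 4**2 = 22.0
    param.trainable_parameter_indices(5)  # [5, 6] within a larger circuit parameter list
    param.unaffected_by(0)                # th1 zeroed: 0 + 4**2 = 16.0
    param.partial_derivative(1)           # the index runs over all lambda arguments
                                          # (features first), so 1 is th1 and
                                          # df/dth1 = x = 2.0

    # plain attribute assignment replaces the removed update_parameters()
    param.trainable = [1.0, 2.0]
    param.features = [5.0]
    param()                               # 5*1 + 2**2 = 9.0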
55 changes: 30 additions & 25 deletions tests/test_parameter.py
@@ -8,28 +8,29 @@ def test_parameter():

     # single feature
     param = Parameter(
         lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
-        [1.5, 2.0, 3.0],
-        feature=[7.0],
+        features=[7.0],
+        trainable=[1.5, 2.0, 3.0],
     )

-    indices = param.get_indices(10)
+    indices = param.trainable_parameter_indices(10)
     assert indices == [10, 11, 12]

-    fixed = param.get_fixed_part(1)
+    fixed = param.unaffected_by(1)
     assert fixed == 73.5

-    factor = param.get_partial_derivative(3)
+    factor = param.partial_derivative(3)
     assert factor == 12.0

-    param.update_parameters(trainable=[15.0, 10.0, 7.0], feature=[5.0])
+    param.trainable = [15.0, 10.0, 7.0]
+    param.features = [5.0]
     gate_value = param()
     assert gate_value == 865

     # single feature, no list
     param2 = Parameter(
         lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
-        [1.5, 2.0, 3.0],
-        feature=[7.0],
+        features=[7.0],
+        trainable=[1.5, 2.0, 3.0],
     )

     gate_value2 = param2()

@@ -38,53 +39,57 @@

     # multiple features
     param = Parameter(
         lambda x1, x2, th1, th2, th3: x1**2 * th1 + x2 * th2 * th3,
-        [1.5, 2.0, 3.0],
-        feature=[7.0, 4.0],
+        features=[7.0, 4.0],
+        trainable=[1.5, 2.0, 3.0],
     )

-    fixed = param.get_fixed_part(1)
+    fixed = param.unaffected_by(1)
     assert fixed == 73.5

-    factor = param.get_partial_derivative(4)
+    factor = param.partial_derivative(4)
     assert factor == 8.0

-    param.update_parameters(trainable=np.array([15.0, 10.0, 7.0]), feature=[5.0, 3.0])
+    param.trainable = np.array([15.0, 10.0, 7.0])
+    param.features = [5.0, 3.0]
     gate_value = param()
     assert gate_value == 585

+    param = Parameter(lambda th1, th2, th3: th1 + th2 * th3, nofeatures=True)
+    param.trainable = [1.0, 2.0, 4.0]
+    param.features = [22.0]
+
+    assert param() == 9.0
+    assert param.features == None
+

 def test_parameter_errors():
     param = Parameter(
         lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
-        [1.5, 2.0, 3.0],
-        feature=[7.0],
+        features=[7.0],
+        trainable=[1.5, 2.0, 3.0],
     )

-    param.update_parameters([1, 1, 1], 1)
+    param.trainable = [1, 1, 1]
+    param.features = 1

     try:
         param()
         assert False
     except Exception as e:
         assert True

-    param.update_parameters([1, 1], [1])
+    param.trainable = [1, 1]
+    param.features = [1]
     try:
         param()
         assert False
     except Exception as e:
         assert True

-    param.update_parameters([1, 1, 1], [1, 1])
+    param.trainable = [1, 1, 1]
+    param.features = [1, 1]
     try:
         param()
         assert False
     except Exception as e:
         assert True

     with pytest.raises(TypeError) as e_info:
         param = Parameter(
             lambda x, th1, th2, th3: x**2 * th1 + th2 * th3**2,
             [1.5, 2.0],
             feature=[7.0],
         )
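
The asserted constants in the first test follow directly from the lambda; a quick hand-check of the single-feature case (plain Python, no qibo imports needed):

    # f(x, th1, th2, th3) = x**2 * th1 + th2 * th3**2 with x = 7.0, th = [1.5, 2.0, 3.0]
    x, th1, th2, th3 = 7.0, 1.5, 2.0, 3.0

    assert x**2 * th1 + 0.0 * th3**2 == 73.5        # unaffected_by(1) zeroes th2
    assert 2 * th2 * th3 == 12.0                    # partial_derivative(3) is df/dth3
    assert 5.0**2 * 15.0 + 10.0 * 7.0**2 == 865.0   # value after the property updates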
