Skip to content

Commit

Permalink
feat: added activations in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BrunoLiegiBastonLiegi committed Nov 21, 2024
1 parent 3e52004 commit c887854
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/qiboml/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,15 @@ def __post_init__(

if self.differentiation == "auto":
self.differentiation = BACKEND_2_DIFFERENTIATION.get(
self.backend.name, "PSR"
self.backend.platform, "PSR"
)

if self.differentiation is not None:
self.differentiation = getattr(Diff, self.differentiation)()

def forward(self, x: torch.Tensor):
if (
self.backend.name != "pytorch"
self.backend.platform != "pytorch"
or self.differentiation is not None
or not self.decoding.analytic
):
Expand Down
47 changes: 36 additions & 11 deletions tests/test_models_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pytest
import torch
from qibo import construct_backend, hamiltonians
from qibo import hamiltonians
from qibo.config import raise_error
from qibo.symbols import Z

Expand Down Expand Up @@ -47,19 +47,45 @@ def build_linear_layer(frontend, input_dim, output_dim):


def build_sequential_model(frontend, layers, binary=False):
layers = layers[:1] + [build_activation(frontend, binary)] + layers[1:]
if frontend.__name__ == "qiboml.models.pytorch":
activation = frontend.torch.nn.Threshold(1, 0)
layers = layers[:1] + [activation] + layers[1:] if binary else layers
return frontend.torch.nn.Sequential(*layers)
elif frontend.__name__ == "qiboml.models.keras":
return frontend.keras.Sequential(layers)
else:
raise_error(RuntimeError, f"Unknown frontend {frontend}.")


def build_activation(frontend, binary=False):
if frontend.__name__ == "qiboml.models.pytorch":

class Activation(frontend.torch.nn.Module):
def forward(self, x):
# normalize
x = x / x.max()
if binary:
x = x.round().abs()
else:
# apply the tanh and rescale by pi
x = np.pi * frontend.torch.nn.functional.tanh(x)
return x

elif frontend.__name__ == "qiboml.models.keras":
pass
else:
raise_error(RuntimeError, f"Unknown frontend {frontend}.")

activation = Activation()
return activation


def random_tensor(frontend, shape, binary=False):
if frontend.__name__ == "qiboml.models.pytorch":
tensor = frontend.torch.randint(0, 2, shape) if binary else torch.randn(shape)
tensor = (
frontend.torch.randint(0, 2, shape).double()
if binary
else frontend.torch.randn(shape)
)
elif frontend.__name__ == "qiboml.models.keras":
tensor = frontend.tf.random.uniform(shape)
else:
Expand All @@ -71,16 +97,15 @@ def train_model(frontend, model, data, target):
max_epochs = 30
if frontend.__name__ == "qiboml.models.pytorch":

optimizer = torch.optim.Adam(model.parameters())
loss_f = torch.nn.MSELoss()
optimizer = frontend.torch.optim.Adam(model.parameters())
loss_f = frontend.torch.nn.MSELoss()

avg_grad, ep = 1.0, 0
shape = model(data[0]).shape
while ep < max_epochs:
ep += 1
avg_grad = 0.0
avg_loss = 0.0
permutation = frontend.torch.randint(0, len(data), (len(data),))
permutation = frontend.torch.randperm(len(data))
for x, y in zip(data[permutation], target[permutation]):
optimizer.zero_grad()
loss = loss_f(model(x), y)
Expand Down Expand Up @@ -113,7 +138,7 @@ def eval_model(frontend, model, data, target=None):
outputs = []

if frontend.__name__ == "qiboml.models.pytorch":
loss_f = torch.nn.MSELoss()
loss_f = frontend.torch.nn.MSELoss()
with torch.no_grad():
for x in data:
outputs.append(model(x))
Expand Down Expand Up @@ -184,7 +209,7 @@ def backprop_test(frontend, model, data, target):
def test_encoding(backend, frontend, layer, seed):
if frontend.__name__ == "qiboml.models.keras":
pytest.skip("keras interface not ready.")
if backend.name not in ("pytorch", "jax"):
if backend.platform not in ("pytorch", "jax"):
pytest.skip("Non pytorch/jax differentiation is not working yet.")

set_seed(frontend, seed)
Expand Down Expand Up @@ -224,7 +249,7 @@ def test_encoding(backend, frontend, layer, seed):
def test_decoding(backend, frontend, layer, seed, analytic):
if frontend.__name__ == "qiboml.models.keras":
pytest.skip("keras interface not ready.")
if backend.name not in ("pytorch", "jax"):
if backend.platform not in ("pytorch", "jax"):
pytest.skip("Non pytorch/jax differentiation is not working yet.")
if analytic and not layer is dec.Expectation:
pytest.skip("Unused analytic argument.")
Expand Down

0 comments on commit c887854

Please sign in to comment.