diff --git a/src/qiboml/models/pytorch.py b/src/qiboml/models/pytorch.py
index e7b9998..d1e624c 100644
--- a/src/qiboml/models/pytorch.py
+++ b/src/qiboml/models/pytorch.py
@@ -39,7 +39,7 @@ def __post_init__(
 
         if self.differentiation == "auto":
             self.differentiation = BACKEND_2_DIFFERENTIATION.get(
-                self.backend.name, "PSR"
+                self.backend.platform, "PSR"
             )
 
         if self.differentiation is not None:
@@ -47,7 +47,7 @@ def __post_init__(
 
     def forward(self, x: torch.Tensor):
         if (
-            self.backend.name != "pytorch"
+            self.backend.platform != "pytorch"
             or self.differentiation is not None
             or not self.decoding.analytic
         ):
diff --git a/tests/test_models_interfaces.py b/tests/test_models_interfaces.py
index f919243..7fedbc7 100644
--- a/tests/test_models_interfaces.py
+++ b/tests/test_models_interfaces.py
@@ -4,7 +4,7 @@
 import numpy as np
 import pytest
 import torch
-from qibo import construct_backend, hamiltonians
+from qibo import hamiltonians
 from qibo.config import raise_error
 from qibo.symbols import Z
 
@@ -47,9 +47,8 @@ def build_linear_layer(frontend, input_dim, output_dim):
 
 
 def build_sequential_model(frontend, layers, binary=False):
+    layers = layers[:1] + [build_activation(frontend, binary)] + layers[1:]
     if frontend.__name__ == "qiboml.models.pytorch":
-        activation = frontend.torch.nn.Threshold(1, 0)
-        layers = layers[:1] + [activation] + layers[1:] if binary else layers
         return frontend.torch.nn.Sequential(*layers)
     elif frontend.__name__ == "qiboml.models.keras":
         return frontend.keras.Sequential(layers)
@@ -57,9 +56,36 @@ def build_sequential_model(frontend, layers, binary=False):
         raise_error(RuntimeError, f"Unknown frontend {frontend}.")
 
 
+def build_activation(frontend, binary=False):
+    if frontend.__name__ == "qiboml.models.pytorch":
+
+        class Activation(frontend.torch.nn.Module):
+            def forward(self, x):
+                # normalize
+                x = x / x.max()
+                if binary:
+                    x = x.round().abs()
+                else:
+                    # apply the tanh and rescale by pi
+                    x = np.pi * frontend.torch.nn.functional.tanh(x)
+                return x
+
+    elif frontend.__name__ == "qiboml.models.keras":
+        pass
+    else:
+        raise_error(RuntimeError, f"Unknown frontend {frontend}.")
+
+    activation = Activation()
+    return activation
+
+
 def random_tensor(frontend, shape, binary=False):
     if frontend.__name__ == "qiboml.models.pytorch":
-        tensor = frontend.torch.randint(0, 2, shape) if binary else torch.randn(shape)
+        tensor = (
+            frontend.torch.randint(0, 2, shape).double()
+            if binary
+            else frontend.torch.randn(shape)
+        )
     elif frontend.__name__ == "qiboml.models.keras":
         tensor = frontend.tf.random.uniform(shape)
     else:
@@ -71,16 +97,15 @@ def train_model(frontend, model, data, target):
     max_epochs = 30
     if frontend.__name__ == "qiboml.models.pytorch":
 
-        optimizer = torch.optim.Adam(model.parameters())
-        loss_f = torch.nn.MSELoss()
+        optimizer = frontend.torch.optim.Adam(model.parameters())
+        loss_f = frontend.torch.nn.MSELoss()
 
         avg_grad, ep = 1.0, 0
-        shape = model(data[0]).shape
         while ep < max_epochs:
             ep += 1
             avg_grad = 0.0
             avg_loss = 0.0
-            permutation = frontend.torch.randint(0, len(data), (len(data),))
+            permutation = frontend.torch.randperm(len(data))
             for x, y in zip(data[permutation], target[permutation]):
                 optimizer.zero_grad()
                 loss = loss_f(model(x), y)
@@ -113,7 +138,7 @@ def eval_model(frontend, model, data, target=None):
     outputs = []
 
     if frontend.__name__ == "qiboml.models.pytorch":
-        loss_f = torch.nn.MSELoss()
+        loss_f = frontend.torch.nn.MSELoss()
         with torch.no_grad():
             for x in data:
                 outputs.append(model(x))
@@ -184,7 +209,7 @@ def backprop_test(frontend, model, data, target):
 def test_encoding(backend, frontend, layer, seed):
     if frontend.__name__ == "qiboml.models.keras":
         pytest.skip("keras interface not ready.")
-    if backend.name not in ("pytorch", "jax"):
+    if backend.platform not in ("pytorch", "jax"):
         pytest.skip("Non pytorch/jax differentiation is not working yet.")
 
     set_seed(frontend, seed)
@@ -224,7 +249,7 @@ def test_encoding(backend, frontend, layer, seed):
 def test_decoding(backend, frontend, layer, seed, analytic):
     if frontend.__name__ == "qiboml.models.keras":
         pytest.skip("keras interface not ready.")
-    if backend.name not in ("pytorch", "jax"):
+    if backend.platform not in ("pytorch", "jax"):
         pytest.skip("Non pytorch/jax differentiation is not working yet.")
     if analytic and not layer is dec.Expectation:
         pytest.skip("Unused analytic argument.")
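For reference, a minimal standalone sketch (not part of the patch) of the behaviour introduced by the new build_activation test helper for the pytorch frontend: a constructor flag and plain torch/numpy imports stand in for the binary closure variable and the frontend indirection used in the tests.

import numpy as np
import torch


class Activation(torch.nn.Module):
    """Mirrors the activation built by build_activation for the pytorch frontend."""

    def __init__(self, binary=False):
        super().__init__()
        self.binary = binary

    def forward(self, x):
        # normalize by the largest entry, as in the test helper
        x = x / x.max()
        if self.binary:
            # collapse to {0, 1} so the output can feed a binary encoder
            return x.round().abs()
        # squash with tanh and rescale to (-pi, pi) for angle-like encodings
        return np.pi * torch.tanh(x)


x = torch.tensor([0.5, -1.0, 2.0])
print(Activation()(x))                                    # values in (-pi, pi)
print(Activation(binary=True)(torch.tensor([0.0, 1.0])))  # tensor([0., 1.])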