diff --git a/src/qiboml/__init__.py b/src/qiboml/__init__.py index a89d0dc..6e84c85 100644 --- a/src/qiboml/__init__.py +++ b/src/qiboml/__init__.py @@ -1,6 +1,5 @@ import importlib.metadata as im from qiboml.backends.__init__ import MetaBackend -from qiboml.models import keras, pytorch __version__ = im.version(__package__) diff --git a/src/qiboml/models/abstract.py b/src/qiboml/models/abstract.py index 305ab3e..25b2869 100644 --- a/src/qiboml/models/abstract.py +++ b/src/qiboml/models/abstract.py @@ -7,7 +7,7 @@ from qibo.config import raise_error from qibo.gates import abstract -from qiboml.backends import TensorflowBackend as JaxBackend +from qiboml.backends import JaxBackend @dataclass diff --git a/src/qiboml/models/keras.py b/src/qiboml/models/keras.py index f13c4cf..fcc7a61 100644 --- a/src/qiboml/models/keras.py +++ b/src/qiboml/models/keras.py @@ -14,25 +14,25 @@ @dataclass class QuantumModel(keras.layers.Layer): # pylint: disable=no-member - def __init__(self, layers: list[QuantumCircuitLayer]): - super().__init__() - nqubits = layers[0].circuit.nqubits - self.layers = layers - for layer in layers[1:]: - if layer.circuit.nqubits != nqubits: + layers: list[QuantumCircuitLayer] + + def __post_init__(self): + super().__post_init__() + for layer in self.layers[1:]: + if layer.circuit.nqubits != self.nqubits: raise_error( RuntimeError, - f"Layer \n{layer}\n has {layer.circuit.nqubits} qubits, but {nqubits} qubits was expected.", + f"Layer \n{layer}\n has {layer.circuit.nqubits} qubits, but {self.nqubits} qubits was expected.", ) if layer.backend.name != self.backend.name: raise_error( RuntimeError, f"Layer \n{layer}\n is using {layer.backend} backend, but {self.backend} backend was expected.", ) - if not isinstance(layers[-1], ed.QuantumDecodingLayer): + if not isinstance(self.layers[-1], ed.QuantumDecodingLayer): raise_error( RuntimeError, - f"The last layer has to be a `QuantumDecodinglayer`, but is {layers[-1]}", + f"The last layer has to be a `QuantumDecodinglayer`, but is {self.layers[-1]}", ) def call(self, x: tf.Tensor) -> tf.Tensor: diff --git a/src/qiboml/models/pytorch.py b/src/qiboml/models/pytorch.py index a052eeb..a0343c7 100644 --- a/src/qiboml/models/pytorch.py +++ b/src/qiboml/models/pytorch.py @@ -13,10 +13,11 @@ @dataclass class QuantumModel(torch.nn.Module): - def __init__(self, layers: list[QuantumCircuitLayer]): - super().__init__() - self.layers = layers - for layer in layers[1:]: + layers: list[QuantumCircuitLayer] + + def __post_init__(self): + super().__post_init__() + for layer in self.layers[1:]: if layer.circuit.nqubits != self.nqubits: raise_error( RuntimeError, @@ -27,16 +28,16 @@ def __init__(self, layers: list[QuantumCircuitLayer]): RuntimeError, f"Layer \n{layer}\n is using {layer.backend} backend, but {self.backend} backend was expected.", ) - for layer in layers: + for layer in self.layers: if len(layer.circuit.get_parameters()) > 0: self.register_parameter( layer.__class__.__name__, torch.nn.Parameter(torch.as_tensor(layer.circuit.get_parameters())), ) - if not isinstance(layers[-1], ed.QuantumDecodingLayer): + if not isinstance(self.layers[-1], ed.QuantumDecodingLayer): raise_error( RuntimeError, - f"The last layer has to be a `QuantumDecodinglayer`, but is {layers[-1]}", + f"The last layer has to be a `QuantumDecodinglayer`, but is {self.layers[-1]}", ) def forward(self, x: torch.Tensor): diff --git a/tests/conftest.py b/tests/conftest.py index 790c474..f8993a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -38,10 +38,14 @@ def get_backend(backend_name): def get_frontend(frontend_name): - import qiboml - - frontend = getattr(qiboml, frontend_name) - setattr(frontend, "__str__", frontend_name) + from qiboml.models import keras, pytorch + + if frontend_name == "keras": + frontend = keras + elif frontend_name == "pytorch": + frontend = pytorch + else: + raise RuntimeError(f"Unknown frontend {frontend_name}.") return frontend