fix: removed keras and pytorch interface import from __init__.py
BrunoLiegiBastonLiegi committed Aug 3, 2024
1 parent bf116f0 commit eb39ca8
Showing 5 changed files with 26 additions and 22 deletions.
1 change: 0 additions & 1 deletion src/qiboml/__init__.py
@@ -1,6 +1,5 @@
 import importlib.metadata as im

 from qiboml.backends.__init__ import MetaBackend
-from qiboml.models import keras, pytorch

 __version__ = im.version(__package__)
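
With this change, importing qiboml no longer pulls in the keras and pytorch interfaces eagerly, so downstream code imports only the frontend it needs. A minimal sketch, assuming the module paths shown in this diff:

    # torch users skip the keras/tensorflow import entirely
    from qiboml.models import pytorch as torch_frontend

    ModelClass = torch_frontend.QuantumModel  # defined in src/qiboml/models/pytorch.py
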
2 changes: 1 addition & 1 deletion src/qiboml/models/abstract.py
@@ -7,7 +7,7 @@
 from qibo.config import raise_error
 from qibo.gates import abstract

-from qiboml.backends import TensorflowBackend as JaxBackend
+from qiboml.backends import JaxBackend


 @dataclass
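
The abstract module now imports the real JaxBackend instead of the TensorflowBackend alias. A quick sanity-check sketch (assumes an install with the jax backend available; the no-argument constructor and the name attribute are assumptions, not taken from this diff):

    from qiboml.backends import JaxBackend

    backend = JaxBackend()
    print(backend.name)  # expected to identify the jax backend, not tensorflow
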
18 changes: 9 additions & 9 deletions src/qiboml/models/keras.py
@@ -14,25 +14,25 @@
 @dataclass
 class QuantumModel(keras.layers.Layer):  # pylint: disable=no-member

-    def __init__(self, layers: list[QuantumCircuitLayer]):
-        super().__init__()
-        nqubits = layers[0].circuit.nqubits
-        self.layers = layers
-        for layer in layers[1:]:
-            if layer.circuit.nqubits != nqubits:
+    layers: list[QuantumCircuitLayer]
+
+    def __post_init__(self):
+        super().__post_init__()
+        for layer in self.layers[1:]:
+            if layer.circuit.nqubits != self.nqubits:
                 raise_error(
                     RuntimeError,
-                    f"Layer \n{layer}\n has {layer.circuit.nqubits} qubits, but {nqubits} qubits was expected.",
+                    f"Layer \n{layer}\n has {layer.circuit.nqubits} qubits, but {self.nqubits} qubits was expected.",
                 )
             if layer.backend.name != self.backend.name:
                 raise_error(
                     RuntimeError,
                     f"Layer \n{layer}\n is using {layer.backend} backend, but {self.backend} backend was expected.",
                 )
-        if not isinstance(layers[-1], ed.QuantumDecodingLayer):
+        if not isinstance(self.layers[-1], ed.QuantumDecodingLayer):
             raise_error(
                 RuntimeError,
-                f"The last layer has to be a `QuantumDecodinglayer`, but is {layers[-1]}",
+                f"The last layer has to be a `QuantumDecodinglayer`, but is {self.layers[-1]}",
             )

     def call(self, x: tf.Tensor) -> tf.Tensor:
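
The keras frontend now declares layers as a dataclass field and moves validation into __post_init__, which runs right after the generated __init__ assigns the fields. A self-contained toy of the same pattern (not qiboml code):

    from dataclasses import dataclass

    @dataclass
    class Pipeline:
        stages: list[str]

        def __post_init__(self):
            # runs immediately after the generated __init__ has set self.stages
            if not self.stages:
                raise RuntimeError("Pipeline needs at least one stage.")

    Pipeline(stages=["encode", "decode"])  # fine
    # Pipeline(stages=[])                  # would raise RuntimeError
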
15 changes: 8 additions & 7 deletions src/qiboml/models/pytorch.py
@@ -13,10 +13,11 @@
 @dataclass
 class QuantumModel(torch.nn.Module):

-    def __init__(self, layers: list[QuantumCircuitLayer]):
-        super().__init__()
-        self.layers = layers
-        for layer in layers[1:]:
+    layers: list[QuantumCircuitLayer]
+
+    def __post_init__(self):
+        super().__post_init__()
+        for layer in self.layers[1:]:
             if layer.circuit.nqubits != self.nqubits:
                 raise_error(
                     RuntimeError,
@@ -27,16 +28,16 @@ def __init__(self, layers: list[QuantumCircuitLayer]):
                     RuntimeError,
                     f"Layer \n{layer}\n is using {layer.backend} backend, but {self.backend} backend was expected.",
                 )
-        for layer in layers:
+        for layer in self.layers:
             if len(layer.circuit.get_parameters()) > 0:
                 self.register_parameter(
                     layer.__class__.__name__,
                     torch.nn.Parameter(torch.as_tensor(layer.circuit.get_parameters())),
                 )
-        if not isinstance(layers[-1], ed.QuantumDecodingLayer):
+        if not isinstance(self.layers[-1], ed.QuantumDecodingLayer):
             raise_error(
                 RuntimeError,
-                f"The last layer has to be a `QuantumDecodinglayer`, but is {layers[-1]}",
+                f"The last layer has to be a `QuantumDecodinglayer`, but is {self.layers[-1]}",
             )

     def forward(self, x: torch.Tensor):
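
The pytorch frontend registers the trainable circuit parameters explicitly via register_parameter. A toy illustration of that mechanism (not qiboml code), showing that the registered tensor becomes visible to optimizers through the module's parameter list:

    import torch

    class Toy(torch.nn.Module):
        def __init__(self, angles):
            super().__init__()
            # wrap the values in a Parameter and register it under a name
            self.register_parameter(
                "angles",
                torch.nn.Parameter(torch.as_tensor(angles, dtype=torch.float64)),
            )

        def forward(self, x):
            return x * self.angles.sum()

    toy = Toy([0.1, 0.2, 0.3])
    print([name for name, _ in toy.named_parameters()])  # ['angles']
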
12 changes: 8 additions & 4 deletions tests/conftest.py
@@ -38,10 +38,14 @@ def get_backend(backend_name):


 def get_frontend(frontend_name):
-    import qiboml
-
-    frontend = getattr(qiboml, frontend_name)
-    setattr(frontend, "__str__", frontend_name)
+    from qiboml.models import keras, pytorch
+
+    if frontend_name == "keras":
+        frontend = keras
+    elif frontend_name == "pytorch":
+        frontend = pytorch
+    else:
+        raise RuntimeError(f"Unknown frontend {frontend_name}.")
     return frontend


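
A hypothetical test built on this helper (the test name and the attribute check are illustrative, not taken from the repository's tests; QuantumModel is the class both frontend modules define in this diff):

    import pytest

    # assumes pytest makes tests/conftest.py importable from the test directory
    from conftest import get_frontend

    @pytest.mark.parametrize("frontend_name", ["keras", "pytorch"])
    def test_frontend_exposes_quantum_model(frontend_name):
        frontend = get_frontend(frontend_name)
        assert hasattr(frontend, "QuantumModel")
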
