-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enabling PSR, analytical diff with pytorch and moving interfaces to dedicated module #42
base: main
Are you sure you want to change the base?
Changes from 17 commits
bed5146
309a26a
5871804
c1534e1
687e843
e015b86
fa83ef2
ea51b91
a412b5e
bb7d7a0
5880001
6110339
c5af099
1b75cae
bd228bc
a9e2481
46de4e9
69b8f12
52ec4cd
2e341f1
9bdf627
16df115
d279f13
1842a1d
5ae1ade
7552520
ae04a26
921485f
590cf5c
45b2063
55f7a06
208ea10
eac535e
6e3870d
3a8e66f
1c0c3c8
f8d944a
73dd6c7
824613f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,10 +14,10 @@ packages = [{ include = "qiboml", from = "src" }] | |
python = ">=3.9,<3.13" | ||
numpy = "^1.26.4" | ||
keras = { version = "^3.0.0", optional = true } | ||
tensorflow = { version = "^2.16.1", markers = "sys_platform == 'linux' or sys_platform == 'darwin'", optional = true } | ||
tensorflow = { version = "^2.16.1", markers = "sys_platform == 'linux' or sys_platform == 'darwin'"} | ||
# TODO: the marker is a temporary solution due to the lack of the tensorflow-io 0.32.0's wheels for Windows, this package is one of | ||
# the tensorflow requirements | ||
torch = { version = "^2.3.1", optional = true } | ||
torch = { version = "^2.3.1"} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here |
||
qibo = {git="https://github.com/qiboteam/qibo"} | ||
jax = "^0.4.25" | ||
jaxlib = "^0.4.25" | ||
|
@@ -33,11 +33,10 @@ pdbpp = "^0.10.3" | |
optional = true | ||
|
||
[tool.poetry.group.tests.dependencies] | ||
torch = "^2.3.1" | ||
tensorflow = { version = "^2.16.1", markers = "sys_platform == 'linux'" } | ||
pytest = "^7.2.1" | ||
pylint = "3.1.0" | ||
pytest-cov = "4.0.0" | ||
qibojit = "^0.1.7" | ||
|
||
[tool.poetry.group.benchmark.dependencies] | ||
pytest-benchmark = { version = "^4.0.0", extras = ["histogram"] } | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,21 +2,13 @@ | |
|
||
from dataclasses import dataclass | ||
|
||
import numpy as np | ||
import torch | ||
from qibo import Circuit | ||
from qibo.backends import Backend, _check_backend | ||
from qibo.config import raise_error | ||
from qibo.backends import Backend | ||
|
||
from qiboml.models.decoding import QuantumDecoding | ||
from qiboml.models.encoding import QuantumEncoding | ||
from qiboml.operations import differentiation as Diff | ||
|
||
BACKEND_2_DIFFERENTIATION = { | ||
"pytorch": None, | ||
"qibolab": "PSR", | ||
"jax": "Jax", | ||
} | ||
from qiboml.operations.differentiation import DifferentiationRule | ||
|
||
|
||
@dataclass(eq=False) | ||
|
@@ -25,30 +17,23 @@ class QuantumModel(torch.nn.Module): | |
encoding: QuantumEncoding | ||
circuit: Circuit | ||
decoding: QuantumDecoding | ||
differentiation: str = "auto" | ||
differentiation_rule: DifferentiationRule = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not a fan of calling these |
||
|
||
def __post_init__( | ||
self, | ||
): | ||
super().__init__() | ||
|
||
circuit = self.encoding.circuit | ||
params = [p for param in self.circuit.get_parameters() for p in param] | ||
params = torch.as_tensor(self.backend.to_numpy(params)).ravel() | ||
params = torch.as_tensor(self.backend.to_numpy(x=params)).ravel() | ||
params.requires_grad = True | ||
self.circuit_parameters = torch.nn.Parameter(params) | ||
|
||
if self.differentiation == "auto": | ||
self.differentiation = BACKEND_2_DIFFERENTIATION.get( | ||
self.backend.name, "PSR" | ||
) | ||
|
||
if self.differentiation is not None: | ||
self.differentiation = getattr(Diff, self.differentiation)() | ||
Comment on lines
-40
to
-46
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens now if the backend is not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
def forward(self, x: torch.Tensor): | ||
if ( | ||
self.backend.name != "pytorch" | ||
or self.differentiation is not None | ||
or self.differentiation_rule is not None | ||
or not self.decoding.analytic | ||
): | ||
x = QuantumModelAutoGrad.apply( | ||
|
@@ -57,7 +42,7 @@ def forward(self, x: torch.Tensor): | |
self.circuit, | ||
self.decoding, | ||
self.backend, | ||
self.differentiation, | ||
self.differentiation_rule, | ||
*list(self.parameters())[0], | ||
) | ||
else: | ||
|
@@ -93,15 +78,15 @@ def forward( | |
circuit: Circuit, | ||
decoding: QuantumDecoding, | ||
backend, | ||
differentiation, | ||
differentiation_rule, | ||
*parameters: list[torch.nn.Parameter], | ||
): | ||
ctx.save_for_backward(x, *parameters) | ||
ctx.encoding = encoding | ||
ctx.circuit = circuit | ||
ctx.decoding = decoding | ||
ctx.backend = backend | ||
ctx.differentiation = differentiation | ||
ctx.differentiation_rule = differentiation_rule | ||
x_clone = x.clone().detach().cpu().numpy() | ||
x_clone = backend.cast(x_clone, dtype=x_clone.dtype) | ||
params = [ | ||
|
@@ -127,7 +112,7 @@ def backward(ctx, grad_output: torch.Tensor): | |
] | ||
grad_input, *gradients = ( | ||
torch.as_tensor(ctx.backend.to_numpy(grad).tolist()) | ||
for grad in ctx.differentiation.evaluate( | ||
for grad in ctx.differentiation_rule.evaluate( | ||
x_clone, ctx.encoding, ctx.circuit, ctx.decoding, ctx.backend, *params | ||
) | ||
) | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -14,7 +14,7 @@ class QuantumDecoding: | |||||
|
||||||
nqubits: int | ||||||
qubits: list[int] = None | ||||||
nshots: int = 1000 | ||||||
nshots: int = None | ||||||
analytic: bool = True | ||||||
backend: Backend = None | ||||||
_circuit: Circuit = None | ||||||
|
@@ -58,7 +58,6 @@ def output_shape(self): | |||||
class Expectation(QuantumDecoding): | ||||||
|
||||||
observable: Union[ndarray, Hamiltonian] = None | ||||||
analytic: bool = False | ||||||
MatteoRobbiati marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
def __post_init__(self): | ||||||
if self.observable is None: | ||||||
|
@@ -69,7 +68,7 @@ def __post_init__(self): | |||||
super().__post_init__() | ||||||
|
||||||
def __call__(self, x: Circuit) -> ndarray: | ||||||
if self.analytic: | ||||||
if self.nshots is None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
better to use the property now that we have it |
||||||
return self.observable.expectation( | ||||||
super().__call__(x).state(), | ||||||
).reshape(1, 1) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would leave this optional
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MatteoRobbiati I would keep this optional
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reminder: add it to test deps