Skip to content

Commit

Permalink
Merge pull request #1459 from qiboteam/pauli_basis_speedup
Browse files Browse the repository at this point in the history
Optimizations for the `qibo.quantum_info.basis.pauli_basis` and `vectorization` function
  • Loading branch information
renatomello authored Oct 8, 2024
2 parents b53ce52 + 30fb9bb commit 3a7a951
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 36 deletions.
45 changes: 21 additions & 24 deletions src/qibo/quantum_info/basis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import reduce
from itertools import product
from typing import Optional

Expand Down Expand Up @@ -92,43 +91,41 @@ def pauli_basis(
backend = _check_backend(backend)

pauli_labels = {"I": matrices.I, "X": matrices.X, "Y": matrices.Y, "Z": matrices.Z}
basis_single = [pauli_labels[label] for label in pauli_order]
dim = 2**nqubits
basis_single = backend.cast([pauli_labels[label] for label in pauli_order])
einsum = np.einsum if backend.name == "tensorflow" else backend.np.einsum

if nqubits > 1:
basis_full = list(product(basis_single, repeat=nqubits))
basis_full = [reduce(np.kron, row) for row in basis_full]
input_indices = [range(3 * i, 3 * (i + 1)) for i in range(nqubits)]
output_indices = (i for indices in zip(*input_indices) for i in indices)
operands = [basis_single for _ in range(nqubits)]
inputs = [item for pair in zip(operands, input_indices) for item in pair]
basis_full = einsum(*inputs, output_indices).reshape(4**nqubits, dim, dim)
else:
basis_full = basis_single

basis_full = backend.cast(basis_full, dtype=basis_full[0].dtype)

if vectorize and sparse:
basis, indexes = [], []
for row in basis_full:
row = vectorization(row, order=order, backend=backend)
row_indexes = backend.np.flatnonzero(row)
indexes.append(row_indexes)
basis.append(row[row_indexes])
del row
if backend.name == "tensorflow":
nonzero = np.nonzero
elif backend.name == "pytorch":
nonzero = lambda x: backend.np.nonzero(x, as_tuple=True)
else:
nonzero = backend.np.nonzero
basis = vectorization(basis_full, order=order, backend=backend)
indices = nonzero(basis)
basis = basis[indices].reshape(-1, dim)
indices = indices[1].reshape(-1, dim)

elif vectorize and not sparse:
basis = [
vectorization(
backend.cast(matrix, dtype=matrix.dtype), order=order, backend=backend
)
for matrix in basis_full
]
basis = vectorization(basis_full, order=order, backend=backend)
else:
basis = basis_full

basis = backend.cast(basis, dtype=basis[0].dtype)

if normalize:
basis = basis / np.sqrt(2**nqubits)

if vectorize and sparse:
indexes = backend.cast(indexes, dtype=indexes[0][0].dtype)

return basis, indexes
return basis, indices

return basis

Expand Down
36 changes: 24 additions & 12 deletions src/qibo/quantum_info/superoperator_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def vectorization(state, order: str = "row", backend=None):
.. math::
|\\rho) = \\sum_{k, l} \\, \\rho_{kl} \\, \\ket{l} \\otimes \\ket{k}
If ``state`` is a 3-dimensional tensor it is interpreted as a batch of states.
Args:
state: state vector or density matrix.
state: statevector, density matrix, an array of statevectors, or an array of density matrices.
order (str, optional): If ``"row"``, vectorization is performed
row-wise. If ``"column"``, vectorization is performed
column-wise. If ``"system"``, a block-vectorization is
Expand All @@ -41,13 +42,13 @@ def vectorization(state, order: str = "row", backend=None):
ndarray: Liouville representation of ``state``.
"""
if (
(len(state.shape) >= 3)
(len(state.shape) > 3)
or (len(state) == 0)
or (len(state.shape) == 2 and state.shape[0] != state.shape[1])
):
raise_error(
TypeError,
f"Object must have dims either (k,) or (k,k), but have dims {state.shape}.",
f"Object must have dims either (k,), (k, k), (N, 1, k) or (N, k, k), but have dims {state.shape}.",
)

if not isinstance(order, str):
Expand All @@ -63,25 +64,36 @@ def vectorization(state, order: str = "row", backend=None):

backend = _check_backend(backend)

dims = state.shape[-1]

if len(state.shape) == 1:
state = backend.np.outer(state, backend.np.conj(state))
elif len(state.shape) == 3 and state.shape[1] == 1:
state = backend.np.einsum(
"aij,akl->aijkl", state, backend.np.conj(state)
).reshape(state.shape[0], dims, dims)

if order == "row":
state = backend.np.reshape(state, (1, -1))[0]
state = backend.np.reshape(state, (-1, dims**2))
elif order == "column":
state = state.T
state = backend.np.reshape(state, (1, -1))[0]
indices = list(range(len(state.shape)))
indices[-2:] = reversed(indices[-2:])
state = backend.np.transpose(state, indices)
state = backend.np.reshape(state, (-1, dims**2))
else:
dim = len(state)
nqubits = int(np.log2(dim))
nqubits = int(np.log2(state.shape[-1]))

new_axis = []
new_axis = [0]
for qubit in range(nqubits):
new_axis += [qubit + nqubits, qubit]
new_axis.extend([qubit + nqubits + 1, qubit + 1])

state = backend.np.reshape(state, [2] * 2 * nqubits)
state = backend.np.reshape(state, [-1] + [2] * 2 * nqubits)
state = backend.np.transpose(state, new_axis)
state = backend.np.reshape(state, (-1,))
state = backend.np.reshape(state, (-1, 2 ** (2 * nqubits)))

state = backend.np.squeeze(
state, axis=tuple(i for i, ax in enumerate(state.shape) if ax == 1)
)

return state

Expand Down
25 changes: 25 additions & 0 deletions tests/test_quantum_info_superoperator_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,31 @@ def test_vectorization(backend, nqubits, order, statevector):
backend.assert_allclose(matrix, matrix_test, atol=PRECISION_TOL)


@pytest.mark.parametrize("order", ["row", "column", "system"])
@pytest.mark.parametrize("nqubits", [1, 2, 3])
@pytest.mark.parametrize("statevector", [True, False])
def test_batched_vectorization(backend, nqubits, order, statevector):
if statevector:
state = backend.cast(
[random_statevector(2**nqubits, 42, backend=backend) for _ in range(3)]
).reshape(3, 1, -1)
else:
state = backend.cast(
[
random_density_matrix(2**nqubits, seed=42, backend=backend)
for _ in range(3)
]
)

batched_vec = vectorization(state, order=order, backend=backend)
for i, element in enumerate(state):
if statevector:
element = element.ravel()
backend.assert_allclose(
batched_vec[i], vectorization(element, order=order, backend=backend)
)


@pytest.mark.parametrize("order", ["row", "column", "system"])
@pytest.mark.parametrize("nqubits", [2, 3, 4, 5])
def test_unvectorization(backend, nqubits, order):
Expand Down

0 comments on commit 3a7a951

Please sign in to comment.