diff --git a/src/qibo/quantum_info/basis.py b/src/qibo/quantum_info/basis.py index e3bf042a60..dae0ac2535 100644 --- a/src/qibo/quantum_info/basis.py +++ b/src/qibo/quantum_info/basis.py @@ -1,4 +1,3 @@ -from functools import reduce from itertools import product from typing import Optional @@ -92,43 +91,41 @@ def pauli_basis( backend = _check_backend(backend) pauli_labels = {"I": matrices.I, "X": matrices.X, "Y": matrices.Y, "Z": matrices.Z} - basis_single = [pauli_labels[label] for label in pauli_order] + dim = 2**nqubits + basis_single = backend.cast([pauli_labels[label] for label in pauli_order]) + einsum = np.einsum if backend.name == "tensorflow" else backend.np.einsum if nqubits > 1: - basis_full = list(product(basis_single, repeat=nqubits)) - basis_full = [reduce(np.kron, row) for row in basis_full] + input_indices = [range(3 * i, 3 * (i + 1)) for i in range(nqubits)] + output_indices = (i for indices in zip(*input_indices) for i in indices) + operands = [basis_single for _ in range(nqubits)] + inputs = [item for pair in zip(operands, input_indices) for item in pair] + basis_full = einsum(*inputs, output_indices).reshape(4**nqubits, dim, dim) else: basis_full = basis_single - basis_full = backend.cast(basis_full, dtype=basis_full[0].dtype) - if vectorize and sparse: - basis, indexes = [], [] - for row in basis_full: - row = vectorization(row, order=order, backend=backend) - row_indexes = backend.np.flatnonzero(row) - indexes.append(row_indexes) - basis.append(row[row_indexes]) - del row + if backend.name == "tensorflow": + nonzero = np.nonzero + elif backend.name == "pytorch": + nonzero = lambda x: backend.np.nonzero(x, as_tuple=True) + else: + nonzero = backend.np.nonzero + basis = vectorization(basis_full, order=order, backend=backend) + indices = nonzero(basis) + basis = basis[indices].reshape(-1, dim) + indices = indices[1].reshape(-1, dim) + elif vectorize and not sparse: - basis = [ - vectorization( - backend.cast(matrix, dtype=matrix.dtype), order=order, backend=backend - ) - for matrix in basis_full - ] + basis = vectorization(basis_full, order=order, backend=backend) else: basis = basis_full - basis = backend.cast(basis, dtype=basis[0].dtype) - if normalize: basis = basis / np.sqrt(2**nqubits) if vectorize and sparse: - indexes = backend.cast(indexes, dtype=indexes[0][0].dtype) - - return basis, indexes + return basis, indices return basis diff --git a/src/qibo/quantum_info/superoperator_transformations.py b/src/qibo/quantum_info/superoperator_transformations.py index 7d74a0ef76..dc52bc00be 100644 --- a/src/qibo/quantum_info/superoperator_transformations.py +++ b/src/qibo/quantum_info/superoperator_transformations.py @@ -27,8 +27,9 @@ def vectorization(state, order: str = "row", backend=None): .. math:: |\\rho) = \\sum_{k, l} \\, \\rho_{kl} \\, \\ket{l} \\otimes \\ket{k} + If ``state`` is a 3-dimensional tensor it is interpreted as a batch of states. Args: - state: state vector or density matrix. + state: statevector, density matrix, an array of statevectors, or an array of density matrices. order (str, optional): If ``"row"``, vectorization is performed row-wise. If ``"column"``, vectorization is performed column-wise. If ``"system"``, a block-vectorization is @@ -41,13 +42,13 @@ def vectorization(state, order: str = "row", backend=None): ndarray: Liouville representation of ``state``. """ if ( - (len(state.shape) >= 3) + (len(state.shape) > 3) or (len(state) == 0) or (len(state.shape) == 2 and state.shape[0] != state.shape[1]) ): raise_error( TypeError, - f"Object must have dims either (k,) or (k,k), but have dims {state.shape}.", + f"Object must have dims either (k,), (k, k), (N, 1, k) or (N, k, k), but have dims {state.shape}.", ) if not isinstance(order, str): @@ -63,25 +64,36 @@ def vectorization(state, order: str = "row", backend=None): backend = _check_backend(backend) + dims = state.shape[-1] + if len(state.shape) == 1: state = backend.np.outer(state, backend.np.conj(state)) + elif len(state.shape) == 3 and state.shape[1] == 1: + state = backend.np.einsum( + "aij,akl->aijkl", state, backend.np.conj(state) + ).reshape(state.shape[0], dims, dims) if order == "row": - state = backend.np.reshape(state, (1, -1))[0] + state = backend.np.reshape(state, (-1, dims**2)) elif order == "column": - state = state.T - state = backend.np.reshape(state, (1, -1))[0] + indices = list(range(len(state.shape))) + indices[-2:] = reversed(indices[-2:]) + state = backend.np.transpose(state, indices) + state = backend.np.reshape(state, (-1, dims**2)) else: - dim = len(state) - nqubits = int(np.log2(dim)) + nqubits = int(np.log2(state.shape[-1])) - new_axis = [] + new_axis = [0] for qubit in range(nqubits): - new_axis += [qubit + nqubits, qubit] + new_axis.extend([qubit + nqubits + 1, qubit + 1]) - state = backend.np.reshape(state, [2] * 2 * nqubits) + state = backend.np.reshape(state, [-1] + [2] * 2 * nqubits) state = backend.np.transpose(state, new_axis) - state = backend.np.reshape(state, (-1,)) + state = backend.np.reshape(state, (-1, 2 ** (2 * nqubits))) + + state = backend.np.squeeze( + state, axis=tuple(i for i, ax in enumerate(state.shape) if ax == 1) + ) return state diff --git a/tests/test_quantum_info_superoperator_transformations.py b/tests/test_quantum_info_superoperator_transformations.py index 8ac14aa1fd..4f382b6300 100644 --- a/tests/test_quantum_info_superoperator_transformations.py +++ b/tests/test_quantum_info_superoperator_transformations.py @@ -169,6 +169,31 @@ def test_vectorization(backend, nqubits, order, statevector): backend.assert_allclose(matrix, matrix_test, atol=PRECISION_TOL) +@pytest.mark.parametrize("order", ["row", "column", "system"]) +@pytest.mark.parametrize("nqubits", [1, 2, 3]) +@pytest.mark.parametrize("statevector", [True, False]) +def test_batched_vectorization(backend, nqubits, order, statevector): + if statevector: + state = backend.cast( + [random_statevector(2**nqubits, 42, backend=backend) for _ in range(3)] + ).reshape(3, 1, -1) + else: + state = backend.cast( + [ + random_density_matrix(2**nqubits, seed=42, backend=backend) + for _ in range(3) + ] + ) + + batched_vec = vectorization(state, order=order, backend=backend) + for i, element in enumerate(state): + if statevector: + element = element.ravel() + backend.assert_allclose( + batched_vec[i], vectorization(element, order=order, backend=backend) + ) + + @pytest.mark.parametrize("order", ["row", "column", "system"]) @pytest.mark.parametrize("nqubits", [2, 3, 4, 5]) def test_unvectorization(backend, nqubits, order):