Skip to content

Commit

Permalink
Merge pull request #1163 from qiboteam/random_statevector
Browse files Browse the repository at this point in the history
Refactoring `random_statevector`
  • Loading branch information
renatomello authored Feb 5, 2024
2 parents c113117 + a72a265 commit f67dc84
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 36 deletions.
32 changes: 7 additions & 25 deletions src/qibo/quantum_info/random_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def random_quantum_channel(
return super_op


def random_statevector(dims: int, haar: bool = False, seed=None, backend=None):
def random_statevector(dims: int, seed=None, backend=None):
"""Creates a random statevector :math:`\\ket{\\psi}`.
.. math::
Expand All @@ -396,10 +396,6 @@ def random_statevector(dims: int, haar: bool = False, seed=None, backend=None):
Args:
dims (int): dimension of the matrix.
haar (bool, optional): if ``True``, statevector is created by sampling a
Haar random unitary :math:`U_{\\text{haar}}` and acting with it on a
random computational basis state :math:`\\ket{k}`, i.e.
:math:`\\ket{\\psi} = U_{\\text{haar}} \\ket{k}`. Defaults to ``False``.
seed (int or :class:`numpy.random.Generator`, optional): Either a generator of
random numbers or a fixed seed to initialize a generator. If ``None``,
initializes a generator with a random seed. Defaults to ``None``.
Expand All @@ -414,9 +410,6 @@ def random_statevector(dims: int, haar: bool = False, seed=None, backend=None):
if dims <= 0:
raise_error(ValueError, "dim must be of type int and >= 1")

if not isinstance(haar, bool):
raise_error(TypeError, f"haar must be type bool, but it is type {type(haar)}.")

if (
seed is not None
and not isinstance(seed, int)
Expand All @@ -426,23 +419,12 @@ def random_statevector(dims: int, haar: bool = False, seed=None, backend=None):
TypeError, "seed must be either type int or numpy.random.Generator."
)

if backend is None: # pragma: no cover
backend = GlobalBackend()

local_state = (
np.random.default_rng(seed) if seed is None or isinstance(seed, int) else seed
)
backend, local_state = _set_backend_and_local_state(seed, backend)

if not haar:
# sample real and imag parts of complex amplitude in [-1, 1]
state = 1j * (2 * local_state.random(dims) - 1)
state += 2 * local_state.random(dims) - 1
state /= np.linalg.norm(state)
state = backend.cast(state, dtype=state.dtype)
else:
# select a random column of a haar random unitary
k = local_state.integers(low=0, high=dims)
state = random_unitary(dims, measure="haar", seed=seed, backend=backend)[:, k]
state = local_state.standard_normal(dims).astype(complex)
state += 1.0j * local_state.standard_normal(dims)
state /= np.linalg.norm(state)
state = backend.cast(state, dtype=state.dtype)

return state

Expand Down Expand Up @@ -566,7 +548,7 @@ def random_density_matrix(
random_gaussian_matrix(dims, rank, seed=local_state, backend=backend),
)
state = np.dot(state, np.transpose(np.conj(state)))
state = state / np.trace(state)
state /= np.trace(state)

state = backend.cast(state, dtype=state.dtype)

Expand Down
4 changes: 1 addition & 3 deletions src/qibo/quantum_info/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,7 @@ def haar_integral(
rand_unit_density, dtype=rand_unit_density.dtype
)
for _ in range(samples):
haar_state = np.reshape(
random_statevector(dim, haar=True, backend=backend), (-1, 1)
)
haar_state = np.reshape(random_statevector(dim, backend=backend), (-1, 1))

rho = haar_state @ np.conj(np.transpose(haar_state))

Expand Down
11 changes: 3 additions & 8 deletions tests/test_quantum_info_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,27 +223,22 @@ def test_random_quantum_channel(backend, representation, measure, rank, order):
)


@pytest.mark.parametrize("haar", [False, True])
@pytest.mark.parametrize("seed", [None, 10, np.random.default_rng(10)])
def test_random_statevector(backend, haar, seed):
def test_random_statevector(backend, seed):
with pytest.raises(TypeError):
dims = "10"
random_statevector(dims, backend=backend)
with pytest.raises(ValueError):
dims = 0
random_statevector(dims, backend=backend)
with pytest.raises(TypeError):
dims = 2
random_statevector(dims, haar=1, backend=backend)
with pytest.raises(TypeError):
dims = 2
random_statevector(dims, seed=0.1, backend=backend)

# tests if random statevector is a pure state
dims = 4
state = random_statevector(dims, haar=haar, seed=seed, backend=backend)
backend.assert_allclose(purity(state) <= 1.0 + PRECISION_TOL, True)
backend.assert_allclose(purity(state) >= 1.0 - PRECISION_TOL, True)
state = random_statevector(dims, seed=seed, backend=backend)
backend.assert_allclose(abs(purity(state) - 1.0) < PRECISION_TOL, True)


@pytest.mark.parametrize("normalize", [False, True])
Expand Down

0 comments on commit f67dc84

Please sign in to comment.