From e3474901f669a8bd17a04af0a7fe16972f1ac8e4 Mon Sep 17 00:00:00 2001 From: Renato Mello Date: Tue, 30 Jan 2024 17:23:09 +0400 Subject: [PATCH] test errors --- src/qibo/quantum_info/entropies.py | 9 ++++-- tests/test_quantum_info_entropies.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/qibo/quantum_info/entropies.py b/src/qibo/quantum_info/entropies.py index 6d3492d45d..ec869ba9c9 100644 --- a/src/qibo/quantum_info/entropies.py +++ b/src/qibo/quantum_info/entropies.py @@ -1,4 +1,5 @@ """Submodule with entropy measures.""" +from typing import Union import numpy as np @@ -151,7 +152,9 @@ def classical_relative_entropy(prob_dist_p, prob_dist_q, base: float = 2, backen return entropy_p - relative -def classical_renyi_entropy(prob_dist, alpha: float, base: float = 2, backend=None): +def classical_renyi_entropy( + prob_dist, alpha: Union[float, int], base: float = 2, backend=None +): """Calculates the classical Rényi entropy :math:`H_{\\alpha}` of a discrete probability distribution. For :math:`\\alpha \\in (0, \\, 1) \\cup (1, \\, \\infty)` and probability distribution @@ -171,7 +174,7 @@ def classical_renyi_entropy(prob_dist, alpha: float, base: float = 2, backend=No Args: prob_dist (ndarray): discrete probability distribution. - alpha (float): order of the Rényi entropy. + alpha (float or int): order of the Rényi entropy. If :math:`\\alpha = 1`, defaults to :func:`qibo.quantum_info.entropies.shannon_entropy`. If :math:`\\alpha = \\infty`, defaults to the `min-entropy `_. @@ -190,7 +193,7 @@ def classical_renyi_entropy(prob_dist, alpha: float, base: float = 2, backend=No # np.float64 is necessary instead of native float because of tensorflow prob_dist = backend.cast(prob_dist, dtype=np.float64) - if not isinstance(alpha, float): + if not isinstance(alpha, (float, int)): raise_error( TypeError, f"alpha must be type float, but it is type {type(alpha)}." ) diff --git a/tests/test_quantum_info_entropies.py b/tests/test_quantum_info_entropies.py index 392d4b108b..e7d19a4ef4 100644 --- a/tests/test_quantum_info_entropies.py +++ b/tests/test_quantum_info_entropies.py @@ -4,6 +4,7 @@ from qibo.config import PRECISION_TOL from qibo.quantum_info.entropies import ( classical_relative_entropy, + classical_renyi_entropy, entanglement_entropy, entropy, relative_entropy, @@ -113,6 +114,48 @@ def test_classical_relative_entropy(backend, base, kind): backend.assert_allclose(divergence, target, atol=1e-5) +@pytest.mark.parametrize("kind", [None, list]) +@pytest.mark.parametrize("base", [2, 10, np.e, 5]) +@pytest.mark.parametrize("alpha", [1, 2, 3, np.inf]) +def test_classical_renyi_entropy(backend, alpha, base, kind): + with pytest.raises(TypeError): + prob = np.array([1.0, 0.0]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha="2", backend=backend) + with pytest.raises(ValueError): + prob = np.array([1.0, 0.0]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha=-2, backend=backend) + with pytest.raises(TypeError): + prob = np.array([1.0, 0.0]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, base="2", backend=backend) + with pytest.raises(ValueError): + prob = np.array([1.0, 0.0]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, base=-2, backend=backend) + with pytest.raises(TypeError): + prob = np.array([[1.0], [0.0]]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, backend=backend) + with pytest.raises(TypeError): + prob = np.array([]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, backend=backend) + with pytest.raises(ValueError): + prob = np.array([1.0, -1.0]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, backend=backend) + with pytest.raises(ValueError): + prob = np.array([1.1, 0.0]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, backend=backend) + with pytest.raises(ValueError): + prob = np.array([0.5, 0.4999999]) + prob = backend.cast(prob, dtype=prob.dtype) + test = classical_renyi_entropy(prob, alpha, backend=backend) + + @pytest.mark.parametrize("check_hermitian", [False, True]) @pytest.mark.parametrize("base", [2, 10, np.e, 5]) def test_entropy(backend, base, check_hermitian):