Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
renatomello committed Jan 30, 2024
1 parent 9d3c10b commit 0fa60cf
Showing 1 changed file with 61 additions and 5 deletions.
66 changes: 61 additions & 5 deletions tests/test_quantum_info_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
hellinger_fidelity,
pqc_integral,
shannon_entropy,
total_variation_distance,
)


Expand Down Expand Up @@ -130,7 +131,52 @@ def test_shannon_entropy(backend, base):
backend.assert_allclose(result, 1.0)


def test_hellinger(backend):
@pytest.mark.parametrize("validate", [False, True])
def test_total_variation_distance(backend, validate):
with pytest.raises(TypeError):
prob = np.random.rand(1, 2)
prob_q = np.random.rand(1, 5)
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, backend=backend)
with pytest.raises(TypeError):
prob = np.random.rand(1, 2)[0]
prob_q = np.array([])
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, backend=backend)
with pytest.raises(ValueError):
prob = np.array([-1, 2.0])
prob_q = np.random.rand(1, 5)[0]
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, validate=True, backend=backend)
with pytest.raises(ValueError):
prob = np.random.rand(1, 2)[0]
prob_q = np.array([1.0, 0.0])
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, validate=True, backend=backend)
with pytest.raises(ValueError):
prob = np.array([1.0, 0.0])
prob_q = np.random.rand(1, 2)[0]
prob = backend.cast(prob, dtype=prob.dtype)
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = total_variation_distance(prob, prob_q, validate=True, backend=backend)

prob_p = np.random.rand(10)
prob_q = np.random.rand(10)
prob_p /= np.sum(prob_p)
prob_q /= np.sum(prob_q)

target = (1 / 2) * np.sum(np.abs(prob_p - prob_q))
distance = total_variation_distance(prob_p, prob_q, validate, backend=backend)

assert distance == target


@pytest.mark.parametrize("validate", [False, True])
def test_hellinger(backend, validate):
with pytest.raises(TypeError):
prob = np.random.rand(1, 2)
prob_q = np.random.rand(1, 5)
Expand Down Expand Up @@ -162,10 +208,20 @@ def test_hellinger(backend):
prob_q = backend.cast(prob_q, dtype=prob_q.dtype)
test = hellinger_distance(prob, prob_q, validate=True, backend=backend)

prob = [1.0, 0.0]
prob_q = [1.0, 0.0]
backend.assert_allclose(hellinger_distance(prob, prob_q, backend=backend), 0.0)
backend.assert_allclose(hellinger_fidelity(prob, prob_q, backend=backend), 1.0)
prob_p = np.random.rand(10)
prob_q = np.random.rand(10)
prob_p /= np.sum(prob_p)
prob_q /= np.sum(prob_q)

target = float(
backend.calculate_norm(np.sqrt(prob_p) - np.sqrt(prob_q)) / np.sqrt(2)
)

distance = hellinger_distance(prob_p, prob_q, validate=validate, backend=backend)
fidelity = hellinger_fidelity(prob_p, prob_q, validate=validate, backend=backend)

assert distance == target
assert fidelity == (1 - target**2) ** 2


def test_haar_integral_errors(backend):
Expand Down

0 comments on commit 0fa60cf

Please sign in to comment.