diff --git a/tests/test_derivative.py b/tests/test_derivative.py index a1c6e97e0e..80b477c263 100644 --- a/tests/test_derivative.py +++ b/tests/test_derivative.py @@ -57,45 +57,34 @@ def test_standard_parameter_shift(backend, nshots, atol, scale_factor, grads): circuit=c, hamiltonian=c, parameter_index=0, nshots=nshots ) - if isinstance(backend, PyTorchBackend): - with pytest.raises(NotImplementedError) as excinfo: - grad = parameter_shift( - circuit=c, hamiltonian=test_hamiltonian, parameter_index=0 - ) - assert ( - str(excinfo.value) - == "PyTorchBackend for the parameter shift rule is not supported." - ) - - else: - # executing all the procedure - grad_0 = parameter_shift( - circuit=c, - hamiltonian=test_hamiltonian, - parameter_index=0, - scale_factor=scale_factor, - nshots=nshots, - ) - grad_1 = parameter_shift( - circuit=c, - hamiltonian=test_hamiltonian, - parameter_index=1, - scale_factor=scale_factor, - nshots=nshots, - ) - grad_2 = parameter_shift( - circuit=c, - hamiltonian=test_hamiltonian, - parameter_index=2, - scale_factor=scale_factor, - nshots=nshots, - ) + # executing all the procedure + grad_0 = parameter_shift( + circuit=c, + hamiltonian=test_hamiltonian, + parameter_index=0, + scale_factor=scale_factor, + nshots=nshots, + ) + grad_1 = parameter_shift( + circuit=c, + hamiltonian=test_hamiltonian, + parameter_index=1, + scale_factor=scale_factor, + nshots=nshots, + ) + grad_2 = parameter_shift( + circuit=c, + hamiltonian=test_hamiltonian, + parameter_index=2, + scale_factor=scale_factor, + nshots=nshots, + ) - # check of known values - # calculated using tf.GradientTape - backend.assert_allclose(grad_0, grads[0], atol=atol) - backend.assert_allclose(grad_1, grads[1], atol=atol) - backend.assert_allclose(grad_2, grads[2], atol=atol) + # check of known values + # calculated using tf.GradientTape + backend.assert_allclose(grad_0, grads[0], atol=atol) + backend.assert_allclose(grad_1, grads[1], atol=atol) + backend.assert_allclose(grad_2, grads[2], atol=atol) @pytest.mark.parametrize("step_size", [10**-i for i in range(5, 10, 1)])