diff --git a/src/qibo/backends/pytorch.py b/src/qibo/backends/pytorch.py index 3df847c6b4..9bc1909a85 100644 --- a/src/qibo/backends/pytorch.py +++ b/src/qibo/backends/pytorch.py @@ -85,11 +85,9 @@ def cast( """ if dtype is None: dtype = self.dtype - elif isinstance(dtype, self.np.dtype): - dtype = dtype elif isinstance(dtype, type): dtype = torch_dtype_dict[dtype.__name__] - else: + elif not isinstance(dtype, torch.dtype): dtype = torch_dtype_dict[str(dtype)] if isinstance(x, self.np.Tensor): diff --git a/tests/test_gates_gates.py b/tests/test_gates_gates.py index f2d5f1db67..77eda97cc0 100644 --- a/tests/test_gates_gates.py +++ b/tests/test_gates_gates.py @@ -436,9 +436,10 @@ def test_u3(backend, seed_state, seed_observable): backend.cast(np.transpose(np.conj(final_state_decompose))) @ observable @ final_state_decompose, - backend.cast(np.transpose(np.conj(target_state))) @ observable @ target_state, + backend.cast(np.transpose(np.conj(target_state))) + @ observable + @ backend.cast(target_state), ) - assert gates.U3(0, theta, phi, lam).qasm_label == "u3" assert not gates.U3(0, theta, phi, lam).clifford assert gates.U3(0, theta, phi, lam).unitary @@ -527,7 +528,9 @@ def test_cy(backend, controlled_by, seed_state, seed_observable): backend.cast(np.transpose(np.conj(final_state_decompose))) @ observable @ final_state_decompose, - backend.cast(np.transpose(np.conj(target_state))) @ observable @ target_state, + backend.cast(np.transpose(np.conj(target_state))) + @ observable + @ backend.cast(target_state), ) assert gates.CY(0, 1).qasm_label == "cy" @@ -571,7 +574,9 @@ def test_cz(backend, controlled_by, seed_state, seed_observable): backend.cast(np.transpose(np.conj(final_state_decompose))) @ observable @ final_state_decompose, - backend.cast(np.transpose(np.conj(target_state))) @ observable @ target_state, + backend.cast(np.transpose(np.conj(target_state))) + @ observable + @ backend.cast(target_state), ) assert gates.CZ(0, 1).qasm_label == "cz" diff --git a/tests/test_measurements.py b/tests/test_measurements.py index 07ba9d9631..51e5bae07b 100644 --- a/tests/test_measurements.py +++ b/tests/test_measurements.py @@ -76,8 +76,11 @@ def test_measurement_gate(backend, n, nshots): def test_multiple_qubit_measurement_gate(backend): c = models.Circuit(2) c.add(gates.X(0)) - c.add(gates.M(0, 1)) + measure = c.add(gates.M(0, 1)) result = backend.execute_circuit(c, nshots=100) + print(result.frequencies()) + print(result.probabilities()) + # print(measure.samples()) target_binary_samples = np.zeros((100, 2)) target_binary_samples[:, 0] = 1 assert_result(