diff --git a/docs/source/data.rst b/docs/source/data.rst index 3360fd37..62c75fa7 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -27,14 +27,14 @@ This module defines standard data structure for pattern commands. .. autoclass:: MeasureUpdate -:mod:`graphix.pauli` module -+++++++++++++++++++++++++++ +:mod:`graphix.fundamentals` module +++++++++++++++++++++++++++++++++++ -This module defines standard data structure for Pauli operators, measurement planes and their transformations. +This module defines standard data structure for Pauli operators. -.. automodule:: graphix.pauli +.. automodule:: graphix.fundamentals -.. currentmodule:: graphix.pauli +.. currentmodule:: graphix.fundamentals .. autoclass:: Axis :members: @@ -51,8 +51,16 @@ This module defines standard data structure for Pauli operators, measurement pla .. autoclass:: Plane :members: -.. autoclass:: Pauli +:mod:`graphix.pauli` module ++++++++++++++++++++++++++++ + +This module defines standard data structure for Pauli operators. + +.. automodule:: graphix.pauli +.. currentmodule:: graphix.pauli + +.. autoclass:: Pauli :mod:`graphix.instruction` module +++++++++++++++++++++++++++++++++ @@ -96,8 +104,3 @@ This module defines standard data structure for gate seqence (circuit model) use .. currentmodule:: graphix.states .. autoclass:: State - - - - - diff --git a/examples/visualization.py b/examples/visualization.py index 0d2c057a..71bab0c6 100644 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -22,7 +22,7 @@ import numpy as np from graphix import Circuit -from graphix.pauli import Plane +from graphix.fundamentals import Plane circuit = Circuit(3) circuit.cnot(0, 1) diff --git a/graphix/_db.py b/graphix/_db.py index 648f2325..3dd4f1d4 100644 --- a/graphix/_db.py +++ b/graphix/_db.py @@ -2,54 +2,41 @@ from __future__ import annotations -from typing import ClassVar, Literal, NamedTuple, TypeVar +from typing import NamedTuple import numpy as np -import numpy.typing as npt -_T = TypeVar("_T", bound=np.generic) +from graphix import utils +from graphix.fundamentals import IXYZ, Sign +from graphix.ops import Ops + +# 24 unique 1-qubit Clifford gates +_C0 = Ops.I # I +_C1 = Ops.X # X +_C2 = Ops.Y # Y +_C3 = Ops.Z # Z +_C4 = Ops.S # S = \sqrt{Z} +_C5 = Ops.SDG # SDG = S^{\dagger} +_C6 = Ops.H # H +_C7 = utils.lock(np.asarray([[1, -1j], [-1j, 1]]) / np.sqrt(2)) # \sqrt{iX} +_C8 = utils.lock(np.asarray([[1, -1], [1, 1]]) / np.sqrt(2)) # \sqrt{iY} +_C9 = utils.lock(np.asarray([[0, 1 - 1j], [-1 - 1j, 0]]) / np.sqrt(2)) # sqrt{I} +_C10 = utils.lock(np.asarray([[0, -1 - 1j], [1 - 1j, 0]]) / np.sqrt(2)) # sqrt{-I} +_C11 = utils.lock(np.asarray([[1, -1], [-1, -1]]) / np.sqrt(2)) # sqrt{I} +_C12 = utils.lock(np.asarray([[-1, -1], [1, -1]]) / np.sqrt(2)) # sqrt{-iY} +_C13 = utils.lock(np.asarray([[1j, -1], [1, -1j]]) / np.sqrt(2)) # sqrt{-I} +_C14 = utils.lock(np.asarray([[1j, 1], [-1, -1j]]) / np.sqrt(2)) # sqrt{-I} +_C15 = utils.lock(np.asarray([[-1, -1j], [-1j, -1]]) / np.sqrt(2)) # sqrt{-iX} +_C16 = utils.lock(np.asarray([[-1 + 1j, 1 + 1j], [-1 + 1j, -1 - 1j]]) / 2) # I^(1/3) +_C17 = utils.lock(np.asarray([[-1 + 1j, -1 - 1j], [1 - 1j, -1 - 1j]]) / 2) # I^(1/3) +_C18 = utils.lock(np.asarray([[1 + 1j, 1 - 1j], [-1 - 1j, 1 - 1j]]) / 2) # I^(1/3) +_C19 = utils.lock(np.asarray([[-1 - 1j, 1 - 1j], [-1 - 1j, -1 + 1j]]) / 2) # I^(1/3) +_C20 = utils.lock(np.asarray([[-1 - 1j, -1 - 1j], [1 - 1j, -1 + 1j]]) / 2) # I^(1/3) +_C21 = utils.lock(np.asarray([[-1 + 1j, -1 + 1j], [1 + 1j, -1 - 1j]]) / 2) # I^(1/3) +_C22 = utils.lock(np.asarray([[1 + 1j, -1 - 1j], [1 - 1j, 1 - 1j]]) / 2) # I^(1/3) +_C23 = utils.lock(np.asarray([[-1 + 1j, 1 - 1j], [-1 - 1j, -1 - 1j]]) / 2) # I^(1/3) -def _lock(data: npt.NDArray[_T]) -> npt.NDArray[np.complex128]: - """Create a true immutable view. - - data must not have aliasing references, otherwise users can still turn on writeable flag of m. - """ - m = data.astype(np.complex128) - m.flags.writeable = False - v = m.view() - assert not v.flags.writeable - return v - - -# 24 Unique 1-qubit Clifford gates -_C0 = _lock(np.asarray([[1, 0], [0, 1]])) # identity -_C1 = _lock(np.asarray([[0, 1], [1, 0]])) # X -_C2 = _lock(np.asarray([[0, -1j], [1j, 0]])) # Y -_C3 = _lock(np.asarray([[1, 0], [0, -1]])) # Z -_C4 = _lock(np.asarray([[1, 0], [0, 1j]])) # S = \sqrt{Z} -_C5 = _lock(np.asarray([[1, 0], [0, -1j]])) # S dagger -_C6 = _lock(np.asarray([[1, 1], [1, -1]]) / np.sqrt(2)) # Hadamard -_C7 = _lock(np.asarray([[1, -1j], [-1j, 1]]) / np.sqrt(2)) # \sqrt{iX} -_C8 = _lock(np.asarray([[1, -1], [1, 1]]) / np.sqrt(2)) # \sqrt{iY} -_C9 = _lock(np.asarray([[0, 1 - 1j], [-1 - 1j, 0]]) / np.sqrt(2)) # sqrt{I} -_C10 = _lock(np.asarray([[0, -1 - 1j], [1 - 1j, 0]]) / np.sqrt(2)) # sqrt{-I} -_C11 = _lock(np.asarray([[1, -1], [-1, -1]]) / np.sqrt(2)) # sqrt{I} -_C12 = _lock(np.asarray([[-1, -1], [1, -1]]) / np.sqrt(2)) # sqrt{-iY} -_C13 = _lock(np.asarray([[1j, -1], [1, -1j]]) / np.sqrt(2)) # sqrt{-I} -_C14 = _lock(np.asarray([[1j, 1], [-1, -1j]]) / np.sqrt(2)) # sqrt{-I} -_C15 = _lock(np.asarray([[-1, -1j], [-1j, -1]]) / np.sqrt(2)) # sqrt{-iX} -_C16 = _lock(np.asarray([[-1 + 1j, 1 + 1j], [-1 + 1j, -1 - 1j]]) / 2) # I^(1/3) -_C17 = _lock(np.asarray([[-1 + 1j, -1 - 1j], [1 - 1j, -1 - 1j]]) / 2) # I^(1/3) -_C18 = _lock(np.asarray([[1 + 1j, 1 - 1j], [-1 - 1j, 1 - 1j]]) / 2) # I^(1/3) -_C19 = _lock(np.asarray([[-1 - 1j, 1 - 1j], [-1 - 1j, -1 + 1j]]) / 2) # I^(1/3) -_C20 = _lock(np.asarray([[-1 - 1j, -1 - 1j], [1 - 1j, -1 + 1j]]) / 2) # I^(1/3) -_C21 = _lock(np.asarray([[-1 + 1j, -1 + 1j], [1 + 1j, -1 - 1j]]) / 2) # I^(1/3) -_C22 = _lock(np.asarray([[1 + 1j, -1 - 1j], [1 - 1j, 1 - 1j]]) / 2) # I^(1/3) -_C23 = _lock(np.asarray([[-1 + 1j, 1 - 1j], [-1 - 1j, -1 - 1j]]) / 2) # I^(1/3) - - -# list of unique 1-qubit Clifford gates CLIFFORD = ( _C0, _C1, @@ -77,7 +64,7 @@ def _lock(data: npt.NDArray[_T]) -> npt.NDArray[np.complex128]: _C23, ) -# readable labels for the 1-qubit Clifford +# Human-readable labels CLIFFORD_LABEL = ( "I", "X", @@ -106,8 +93,7 @@ def _lock(data: npt.NDArray[_T]) -> npt.NDArray[np.complex128]: "I^{1/3}", ) -# Multiplying single-qubit Clifford gates result in a single-qubit Clifford gate. -# CLIFFORD_MUL provides the result of Clifford gate multiplications by Clifford index (see above). +# Clifford(CLIFFORD_MUL[i][j]) ~ CLIFFORD[i] @ CLIFFORD[j] (up to phase) CLIFFORD_MUL = ( (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23), (1, 0, 3, 2, 9, 10, 8, 15, 6, 4, 5, 12, 11, 14, 13, 7, 19, 18, 17, 16, 22, 23, 20, 21), @@ -135,53 +121,54 @@ def _lock(data: npt.NDArray[_T]) -> npt.NDArray[np.complex128]: (23, 22, 21, 20, 14, 15, 10, 8, 9, 7, 13, 5, 4, 6, 12, 11, 2, 3, 0, 1, 17, 16, 19, 18), ) -# Conjugation of Clifford gates result in a Clifford gate. -# CLIFFORD_CONJ provides the Clifford index of conjugated matrix. -# Example (S and S dagger): CLIFFORD_CONJ[4] = 5 -# WARNING: CLIFFORD[i].conj().T is not necessarily equal to -# CLIFFORD[CLIFFORD_CONJ[i]] in general: the phase may differ. -# For instance, CLIFFORD[7].conj().T = - CLIFFORD[CLIFFORD_CONJ[7]] +# Clifford(CLIFFORD_CONJ[i]) ~ CLIFFORD[i].H (up to phase) CLIFFORD_CONJ = (0, 1, 2, 3, 5, 4, 6, 15, 12, 9, 10, 11, 8, 13, 14, 7, 20, 22, 23, 21, 16, 19, 17, 18) -class _CliffordMeasure(NamedTuple): - """NamedTuple just for documentation purposes.""" +class _CM(NamedTuple): + """Pauli string and sign.""" - pstr: Literal["X", "Y", "Z"] - sign: Literal[-1, +1] + pstr: IXYZ + sign: Sign + + +class _CMTuple(NamedTuple): + x: _CM + y: _CM + z: _CM # Conjugation of Pauli gates P with Clifford gate C, # i.e. C @ P @ C^dagger result in Pauli group, i.e. {\pm} \times {X, Y, Z}. # CLIFFORD_MEASURE contains the effect of Clifford conjugation of Pauli gates. CLIFFORD_MEASURE = ( - (_CliffordMeasure("X", +1), _CliffordMeasure("Y", +1), _CliffordMeasure("Z", +1)), - (_CliffordMeasure("X", +1), _CliffordMeasure("Y", -1), _CliffordMeasure("Z", -1)), - (_CliffordMeasure("X", -1), _CliffordMeasure("Y", +1), _CliffordMeasure("Z", -1)), - (_CliffordMeasure("X", -1), _CliffordMeasure("Y", -1), _CliffordMeasure("Z", +1)), - (_CliffordMeasure("Y", -1), _CliffordMeasure("X", +1), _CliffordMeasure("Z", +1)), - (_CliffordMeasure("Y", +1), _CliffordMeasure("X", -1), _CliffordMeasure("Z", +1)), - (_CliffordMeasure("Z", +1), _CliffordMeasure("Y", -1), _CliffordMeasure("X", +1)), - (_CliffordMeasure("X", +1), _CliffordMeasure("Z", -1), _CliffordMeasure("Y", +1)), - (_CliffordMeasure("Z", +1), _CliffordMeasure("Y", +1), _CliffordMeasure("X", -1)), - (_CliffordMeasure("Y", -1), _CliffordMeasure("X", -1), _CliffordMeasure("Z", -1)), - (_CliffordMeasure("Y", +1), _CliffordMeasure("X", +1), _CliffordMeasure("Z", -1)), - (_CliffordMeasure("Z", -1), _CliffordMeasure("Y", -1), _CliffordMeasure("X", -1)), - (_CliffordMeasure("Z", -1), _CliffordMeasure("Y", +1), _CliffordMeasure("X", +1)), - (_CliffordMeasure("X", -1), _CliffordMeasure("Z", -1), _CliffordMeasure("Y", -1)), - (_CliffordMeasure("X", -1), _CliffordMeasure("Z", +1), _CliffordMeasure("Y", +1)), - (_CliffordMeasure("X", +1), _CliffordMeasure("Z", +1), _CliffordMeasure("Y", -1)), - (_CliffordMeasure("Z", +1), _CliffordMeasure("X", +1), _CliffordMeasure("Y", +1)), - (_CliffordMeasure("Z", -1), _CliffordMeasure("X", +1), _CliffordMeasure("Y", -1)), - (_CliffordMeasure("Z", -1), _CliffordMeasure("X", -1), _CliffordMeasure("Y", +1)), - (_CliffordMeasure("Z", +1), _CliffordMeasure("X", -1), _CliffordMeasure("Y", -1)), - (_CliffordMeasure("Y", +1), _CliffordMeasure("Z", +1), _CliffordMeasure("X", +1)), - (_CliffordMeasure("Y", -1), _CliffordMeasure("Z", -1), _CliffordMeasure("X", +1)), - (_CliffordMeasure("Y", +1), _CliffordMeasure("Z", -1), _CliffordMeasure("X", -1)), - (_CliffordMeasure("Y", -1), _CliffordMeasure("Z", +1), _CliffordMeasure("X", -1)), + _CMTuple(_CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.Z, Sign.PLUS)), + _CMTuple(_CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.Z, Sign.MINUS)), + _CMTuple(_CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.Z, Sign.MINUS)), + _CMTuple(_CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.Z, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Z, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Z, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.X, Sign.PLUS)), + _CMTuple(_CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.Y, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.X, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Z, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Z, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.X, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.X, Sign.PLUS)), + _CMTuple(_CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.Y, Sign.MINUS)), + _CMTuple(_CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.Y, Sign.PLUS)), + _CMTuple(_CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.Y, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Y, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.X, Sign.PLUS), _CM(IXYZ.Y, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Y, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.X, Sign.MINUS), _CM(IXYZ.Y, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.X, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.X, Sign.PLUS)), + _CMTuple(_CM(IXYZ.Y, Sign.PLUS), _CM(IXYZ.Z, Sign.MINUS), _CM(IXYZ.X, Sign.MINUS)), + _CMTuple(_CM(IXYZ.Y, Sign.MINUS), _CM(IXYZ.Z, Sign.PLUS), _CM(IXYZ.X, Sign.MINUS)), ) -# Decomposition of Clifford gates with H, S and Z. +# Decomposition of Clifford gates with H, S and Z (up to phase). CLIFFORD_HSZ_DECOMPOSITION = ( (0,), (6, 3, 6), @@ -237,59 +224,3 @@ class _CliffordMeasure(NamedTuple): ("h", "x", "sdg"), ("h", "x", "s"), ) - - -class WellKnownMatrix: - """Collection of well-known matrices.""" - - I: ClassVar = _C0 - X: ClassVar = _C1 - Y: ClassVar = _C2 - Z: ClassVar = _C3 - S: ClassVar = _C4 - SDG: ClassVar = _C5 - H: ClassVar = _C6 - CZ: ClassVar = _lock( - np.asarray( - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, -1], - ], - ) - ) - CNOT: ClassVar = _lock( - np.asarray( - [ - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 0, 1], - [0, 0, 1, 0], - ], - ) - ) - SWAP: ClassVar = _lock( - np.asarray( - [ - [1, 0, 0, 0], - [0, 0, 1, 0], - [0, 1, 0, 0], - [0, 0, 0, 1], - ], - ) - ) - CCX: ClassVar = _lock( - np.asarray( - [ - [1, 0, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 0, 0, 1, 0], - ], - ) - ) diff --git a/graphix/clifford.py b/graphix/clifford.py index 5cc01932..713e5e7a 100644 --- a/graphix/clifford.py +++ b/graphix/clifford.py @@ -3,9 +3,12 @@ from __future__ import annotations import copy -import dataclasses +import math from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +import numpy as np +import typing_extensions from graphix._db import ( CLIFFORD, @@ -16,25 +19,14 @@ CLIFFORD_MUL, CLIFFORD_TO_QASM3, ) -from graphix.pauli import IXYZ, ComplexUnit, Pauli, Sign +from graphix.fundamentals import IXYZ, ComplexUnit +from graphix.measurements import Domains +from graphix.pauli import Pauli if TYPE_CHECKING: - import numpy as np import numpy.typing as npt -@dataclasses.dataclass -class Domains: - """ - Represent `X^sZ^t` where s and t are XOR of results from given sets of indices. - - This representation is used in `Clifford.commute_domains`. - """ - - s_domain: set[int] - t_domain: set[int] - - class Clifford(Enum): """Clifford gate.""" @@ -77,9 +69,35 @@ def matrix(self) -> npt.NDArray[np.complex128]: """Return the matrix of the Clifford gate.""" return CLIFFORD[self.value] + @staticmethod + def try_from_matrix(mat: npt.NDArray[Any]) -> Clifford | None: + """Find the Clifford gate from the matrix. + + Return `None` if not found. + + Notes + ----- + Global phase is ignored. + """ + if mat.shape != (2, 2): + return None + for ci in Clifford: + mi = ci.matrix + for piv, piv_ in zip(mat.flat, mi.flat): + if math.isclose(abs(piv), 0): + continue + if math.isclose(abs(piv_), 0): + continue + if np.allclose(mat / piv, mi / piv_): + return ci + return None + def __repr__(self) -> str: """Return the Clifford expression on the form of HSZ decomposition.""" - return " @ ".join([f"Clifford.{gate}" for gate in self.hsz]) + formula = " @ ".join([f"Clifford.{gate}" for gate in self.hsz]) + if len(self.hsz) == 1: + return formula + return f"({formula})" def __str__(self) -> str: """Return the name of the Clifford gate.""" @@ -111,8 +129,15 @@ def measure(self, pauli: Pauli) -> Pauli: if pauli.symbol == IXYZ.I: return copy.deepcopy(pauli) table = CLIFFORD_MEASURE[self.value] - symbol, sign = table[pauli.symbol.value] - return pauli.unit * Pauli(IXYZ[symbol], ComplexUnit(Sign(sign), False)) + if pauli.symbol == IXYZ.X: + symbol, sign = table.x + elif pauli.symbol == IXYZ.Y: + symbol, sign = table.y + elif pauli.symbol == IXYZ.Z: + symbol, sign = table.z + else: + typing_extensions.assert_never(pauli.symbol) + return pauli.unit * Pauli(symbol, ComplexUnit.from_properties(sign=sign)) def commute_domains(self, domains: Domains) -> Domains: """ diff --git a/graphix/command.py b/graphix/command.py index 6bd30f5b..11c13c60 100644 --- a/graphix/command.py +++ b/graphix/command.py @@ -10,9 +10,11 @@ import numpy as np -from graphix import type_utils -from graphix.clifford import Clifford, Domains -from graphix.pauli import Pauli, Plane, Sign +from graphix import utils +from graphix.clifford import Clifford +from graphix.fundamentals import Plane, Sign +from graphix.measurements import Domains +from graphix.pauli import Pauli from graphix.states import BasicStates, State Node = int @@ -36,7 +38,7 @@ class _KindChecker: def __init_subclass__(cls) -> None: super().__init_subclass__() - type_utils.check_kind(cls, {"CommandKind": CommandKind, "Clifford": Clifford}) + utils.check_kind(cls, {"CommandKind": CommandKind, "Clifford": Clifford}) @dataclasses.dataclass @@ -159,12 +161,9 @@ def compute(plane: Plane, s: bool, t: bool, clifford_gate: Clifford) -> MeasureU cos_pauli = clifford_gate.measure(Pauli.from_axis(plane.cos)) sin_pauli = clifford_gate.measure(Pauli.from_axis(plane.sin)) exchange = cos_pauli.axis != new_plane.cos - if exchange == (cos_pauli.unit.sign == sin_pauli.unit.sign): - coeff = -1 - else: - coeff = 1 + coeff = -1 if exchange == (cos_pauli.unit.sign == sin_pauli.unit.sign) else 1 add_term: float = 0 - if cos_pauli.unit.sign == Sign.Minus: + if cos_pauli.unit.sign == Sign.MINUS: add_term += np.pi if exchange: add_term = np.pi / 2 - add_term diff --git a/graphix/fundamentals.py b/graphix/fundamentals.py new file mode 100644 index 00000000..1fecf0e1 --- /dev/null +++ b/graphix/fundamentals.py @@ -0,0 +1,307 @@ +"""Fundamental components related to quantum mechanics.""" + +from __future__ import annotations + +import enum +import math +import sys +import typing +from enum import Enum +from typing import TYPE_CHECKING, SupportsComplex, SupportsFloat, SupportsIndex + +import typing_extensions + +from graphix.ops import Ops + +if TYPE_CHECKING: + import numpy as np + import numpy.typing as npt + + +if sys.version_info >= (3, 10): + SupportsComplexCtor = SupportsComplex | SupportsFloat | SupportsIndex | complex +else: # pragma: no cover + from typing import Union + + SupportsComplexCtor = Union[SupportsComplex, SupportsFloat, SupportsIndex, complex] + + +class Sign(Enum): + """Sign, plus or minus.""" + + PLUS = 1 + MINUS = -1 + + def __str__(self) -> str: + """Return `+` or `-`.""" + if self == Sign.PLUS: + return "+" + return "-" + + @staticmethod + def plus_if(b: bool) -> Sign: + """Return `+` if `b` is `True`, `-` otherwise.""" + if b: + return Sign.PLUS + return Sign.MINUS + + @staticmethod + def minus_if(b: bool) -> Sign: + """Return `-` if `b` is `True`, `+` otherwise.""" + if b: + return Sign.MINUS + return Sign.PLUS + + def __neg__(self) -> Sign: + """Swap the sign.""" + return Sign.minus_if(self == Sign.PLUS) + + @typing.overload + def __mul__(self, other: Sign) -> Sign: ... + + @typing.overload + def __mul__(self, other: int) -> int: ... + + @typing.overload + def __mul__(self, other: float) -> float: ... + + @typing.overload + def __mul__(self, other: complex) -> complex: ... + + def __mul__(self, other: Sign | complex) -> Sign | int | float | complex: + """Multiply the sign with another sign or a number.""" + if isinstance(other, Sign): + return Sign.plus_if(self == other) + if isinstance(other, int): + return int(self) * other + if isinstance(other, float): + return float(self) * other + if isinstance(other, complex): + return complex(self) * other + return NotImplemented + + @typing.overload + def __rmul__(self, other: int) -> int: ... + + @typing.overload + def __rmul__(self, other: float) -> float: ... + + @typing.overload + def __rmul__(self, other: complex) -> complex: ... + + def __rmul__(self, other: complex) -> int | float | complex: + """Multiply the sign with a number.""" + if isinstance(other, (int, float, complex)): + return self.__mul__(other) + return NotImplemented + + def __int__(self) -> int: + """Return `1` for `+` and `-1` for `-`.""" + # mypy does not infer the return type correctly + return self.value # type: ignore[no-any-return] + + def __float__(self) -> float: + """Return `1.0` for `+` and `-1.0` for `-`.""" + return float(self.value) + + def __complex__(self) -> complex: + """Return `1.0 + 0j` for `+` and `-1.0 + 0j` for `-`.""" + return complex(self.value) + + +class ComplexUnit(Enum): + """ + Complex unit: 1, -1, j, -j. + + Complex units can be multiplied with other complex units, + with Python constants 1, -1, 1j, -1j, and can be negated. + """ + + # HACK: complex(u) == (1j) ** u.value for all u in ComplexUnit. + + ONE = 0 + J = 1 + MINUS_ONE = 2 + MINUS_J = 3 + + @staticmethod + def try_from(value: ComplexUnit | SupportsComplexCtor) -> ComplexUnit | None: + """Return the ComplexUnit instance if the value is compatible, None otherwise.""" + if isinstance(value, ComplexUnit): + return value + try: + value = complex(value) + except Exception: + return None + if value == 1: + return ComplexUnit.ONE + if value == -1: + return ComplexUnit.MINUS_ONE + if value == 1j: + return ComplexUnit.J + if value == -1j: + return ComplexUnit.MINUS_J + return None + + @staticmethod + def from_properties(*, sign: Sign = Sign.PLUS, is_imag: bool = False) -> ComplexUnit: + """Construct ComplexUnit from its properties.""" + osign = 0 if sign == Sign.PLUS else 2 + oimag = 1 if is_imag else 0 + return ComplexUnit(osign + oimag) + + @property + def sign(self) -> Sign: + """Return the sign.""" + return Sign.plus_if(self.value < 2) + + @property + def is_imag(self) -> bool: + """Return `True` if `j` or `-j`.""" + return bool(self.value % 2) + + def __complex__(self) -> complex: + """Return the unit as complex number.""" + ret: complex = 1j**self.value + return ret + + def __str__(self) -> str: + """Return a string representation of the unit.""" + result = "1j" if self.is_imag else "1" + if self.sign == Sign.MINUS: + result = "-" + result + return result + + def __mul__(self, other: ComplexUnit | SupportsComplexCtor) -> ComplexUnit: + """Multiply the complex unit with a number.""" + if isinstance(other, ComplexUnit): + return ComplexUnit((self.value + other.value) % 4) + if other_ := ComplexUnit.try_from(other): + return self.__mul__(other_) + return NotImplemented + + def __rmul__(self, other: SupportsComplexCtor) -> ComplexUnit: + """Multiply the complex unit with a number.""" + return self.__mul__(other) + + def __neg__(self) -> ComplexUnit: + """Return the opposite of the complex unit.""" + return ComplexUnit((self.value + 2) % 4) + + +class IXYZ(Enum): + """I, X, Y or Z.""" + + I = enum.auto() + X = enum.auto() + Y = enum.auto() + Z = enum.auto() + + @property + def matrix(self) -> npt.NDArray[np.complex128]: + """Return the matrix representation.""" + if self == IXYZ.I: + return Ops.I + if self == IXYZ.X: + return Ops.X + if self == IXYZ.Y: + return Ops.Y + if self == IXYZ.Z: + return Ops.Z + typing_extensions.assert_never(self) + + +class Axis(Enum): + """Axis: `X`, `Y` or `Z`.""" + + X = enum.auto() + Y = enum.auto() + Z = enum.auto() + + @property + def matrix(self) -> npt.NDArray[np.complex128]: + """Return the matrix representation.""" + if self == Axis.X: + return Ops.X + if self == Axis.Y: + return Ops.Y + if self == Axis.Z: + return Ops.Z + typing_extensions.assert_never(self) + + +class Plane(Enum): + # TODO: Refactor using match + """Plane: `XY`, `YZ` or `XZ`.""" + + XY = enum.auto() + YZ = enum.auto() + XZ = enum.auto() + + @property + def axes(self) -> tuple[Axis, Axis]: + """Return the pair of axes that carry the plane.""" + if self == Plane.XY: + return (Axis.X, Axis.Y) + if self == Plane.YZ: + return (Axis.Y, Axis.Z) + if self == Plane.XZ: + return (Axis.X, Axis.Z) + typing_extensions.assert_never(self) + + @property + def orth(self) -> Axis: + """Return the axis orthogonal to the plane.""" + if self == Plane.XY: + return Axis.Z + if self == Plane.YZ: + return Axis.X + if self == Plane.XZ: + return Axis.Y + typing_extensions.assert_never(self) + + @property + def cos(self) -> Axis: + """Return the axis of the plane that conventionally carries the cos.""" + if self == Plane.XY: + return Axis.X + if self == Plane.YZ: + return Axis.Z # former convention was Y + if self == Plane.XZ: + return Axis.Z # former convention was X + typing_extensions.assert_never(self) + + @property + def sin(self) -> Axis: + """Return the axis of the plane that conventionally carries the sin.""" + if self == Plane.XY: + return Axis.Y + if self == Plane.YZ: + return Axis.Y # former convention was Z + if self == Plane.XZ: + return Axis.X # former convention was Z + typing_extensions.assert_never(self) + + def polar(self, angle: float) -> tuple[float, float, float]: + """Return the Cartesian coordinates of the point of module 1 at the given angle, following the conventional orientation for cos and sin.""" + pp = (self.cos, self.sin) + if pp == (Axis.X, Axis.Y): + return (math.cos(angle), math.sin(angle), 0) + if pp == (Axis.Z, Axis.Y): + return (0, math.sin(angle), math.cos(angle)) + if pp == (Axis.Z, Axis.X): + return (math.sin(angle), 0, math.cos(angle)) + raise RuntimeError("Unreachable.") # pragma: no cover + + @staticmethod + def from_axes(a: Axis, b: Axis) -> Plane: + """Return the plane carried by the given axes.""" + ab = {a, b} + if ab == {Axis.X, Axis.Y}: + return Plane.XY + if ab == {Axis.Y, Axis.Z}: + return Plane.YZ + if ab == {Axis.X, Axis.Z}: + return Plane.XZ + assert a == b + raise ValueError(f"Cannot make a plane giving the same axis {a} twice.") diff --git a/graphix/generator.py b/graphix/generator.py index 17faf9ff..cb5a62c6 100644 --- a/graphix/generator.py +++ b/graphix/generator.py @@ -3,9 +3,9 @@ from __future__ import annotations from graphix.command import E, M, N, X, Z +from graphix.fundamentals import Plane from graphix.gflow import find_flow, find_gflow, find_odd_neighbor, get_layers from graphix.pattern import Pattern -from graphix.pauli import Plane def generate_from_graph(graph, angles, inputs, outputs, meas_planes=None): diff --git a/graphix/gflow.py b/graphix/gflow.py index 56baf1a4..2f17b39a 100644 --- a/graphix/gflow.py +++ b/graphix/gflow.py @@ -20,10 +20,10 @@ import numpy as np import sympy as sp -from graphix import pauli +from graphix import utils from graphix.command import CommandKind +from graphix.fundamentals import Plane from graphix.linalg import MatGF2 -from graphix.pauli import Plane if TYPE_CHECKING: from graphix.pattern import Pattern @@ -1372,18 +1372,18 @@ def get_pauli_nodes( l_x, l_y, l_z = set(), set(), set() for node, plane in meas_planes.items(): if plane == Plane.XY: - if pauli.is_int(meas_angles[node]): # measurement angle is integer + if utils.is_integer(meas_angles[node]): # measurement angle is integer l_x |= {node} - elif pauli.is_int(2 * meas_angles[node]): # measurement angle is half integer + elif utils.is_integer(2 * meas_angles[node]): # measurement angle is half integer l_y |= {node} elif plane == Plane.XZ: - if pauli.is_int(meas_angles[node]): + if utils.is_integer(meas_angles[node]): l_z |= {node} - elif pauli.is_int(2 * meas_angles[node]): + elif utils.is_integer(2 * meas_angles[node]): l_x |= {node} elif plane == Plane.YZ: - if pauli.is_int(meas_angles[node]): + if utils.is_integer(meas_angles[node]): l_y |= {node} - elif pauli.is_int(2 * meas_angles[node]): + elif utils.is_integer(2 * meas_angles[node]): l_z |= {node} return l_x, l_y, l_z diff --git a/graphix/instruction.py b/graphix/instruction.py index 2759f92c..42fa6044 100644 --- a/graphix/instruction.py +++ b/graphix/instruction.py @@ -8,8 +8,8 @@ from enum import Enum from typing import ClassVar, Literal, Union -from graphix import type_utils -from graphix.pauli import Plane +from graphix import utils +from graphix.fundamentals import Plane class InstructionKind(Enum): @@ -39,7 +39,7 @@ class _KindChecker: def __init_subclass__(cls) -> None: super().__init_subclass__() - type_utils.check_kind(cls, {"InstructionKind": InstructionKind, "Plane": Plane}) + utils.check_kind(cls, {"InstructionKind": InstructionKind, "Plane": Plane}) @dataclasses.dataclass diff --git a/graphix/measurements.py b/graphix/measurements.py new file mode 100644 index 00000000..89626387 --- /dev/null +++ b/graphix/measurements.py @@ -0,0 +1,63 @@ +"""Data structure for single-qubit measurements in MBQC.""" + +from __future__ import annotations + +import dataclasses +import math +from typing import NamedTuple + +from graphix import utils +from graphix.fundamentals import Axis, Plane, Sign + + +@dataclasses.dataclass +class Domains: + """Represent `X^sZ^t` where s and t are XOR of results from given sets of indices.""" + + s_domain: set[int] + t_domain: set[int] + + +class Measurement(NamedTuple): + """An MBQC measurement. + + :param angle: the angle of the measurement. Should be between [0, 2) + :param plane: the measurement plane + """ + + angle: float + plane: Plane + + def isclose(self, other: Measurement, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: + """Compare if two measurements have the same plane and their angles are close. + + Example + ------- + >>> from graphix.opengraph import Measurement + >>> from graphix.fundamentals import Plane + >>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.XY)) + True + >>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.YZ)) + False + >>> Measurement(0.1, Plane.XY).isclose(Measurement(0.0, Plane.XY)) + False + """ + return math.isclose(self.angle, other.angle, rel_tol=rel_tol, abs_tol=abs_tol) and self.plane == other.plane + + +class PauliMeasurement(NamedTuple): + """Pauli measurement.""" + + axis: Axis + sign: Sign + + @staticmethod + def try_from(plane: Plane, angle: float) -> PauliMeasurement | None: + """Return the Pauli measurement description if a given measure is Pauli.""" + angle_double = 2 * angle + if not utils.is_integer(angle_double): + return None + angle_double_mod_4 = int(angle_double) % 4 + axis = plane.cos if angle_double_mod_4 % 2 == 0 else plane.sin + sign = Sign.minus_if(angle_double_mod_4 >= 2) + return PauliMeasurement(axis, sign) diff --git a/graphix/opengraph.py b/graphix/opengraph.py index 82de7ff9..05eba499 100644 --- a/graphix/opengraph.py +++ b/graphix/opengraph.py @@ -2,45 +2,16 @@ from __future__ import annotations -import math from dataclasses import dataclass from typing import TYPE_CHECKING import networkx as nx from graphix.generator import generate_from_graph +from graphix.measurements import Measurement if TYPE_CHECKING: from graphix.pattern import Pattern - from graphix.pauli import Plane - - -@dataclass(frozen=True) -class Measurement: - """An MBQC measurement. - - :param angle: the angle of the measurement. Should be between [0, 2) - :param plane: the measurement plane - """ - - angle: float - plane: Plane - - def isclose(self, other: Measurement, rel_tol: float = 1e-09, abs_tol: float = 0.0) -> bool: - """Compare if two measurements have the same plane and their angles are close. - - Example - ------- - >>> from graphix.opengraph import Measurement - >>> from graphix.pauli import Plane - >>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.XY)) - True - >>> Measurement(0.0, Plane.XY).isclose(Measurement(0.0, Plane.YZ)) - False - >>> Measurement(0.1, Plane.XY).isclose(Measurement(0.0, Plane.XY)) - False - """ - return math.isclose(self.angle, other.angle, rel_tol=rel_tol, abs_tol=abs_tol) and self.plane == other.plane @dataclass(frozen=True) diff --git a/graphix/ops.py b/graphix/ops.py index 538c65dc..ae9b93e7 100644 --- a/graphix/ops.py +++ b/graphix/ops.py @@ -4,16 +4,69 @@ from functools import reduce from itertools import product +from typing import ClassVar import numpy as np import numpy.typing as npt -from graphix._db import WellKnownMatrix +from graphix import utils -class Ops(WellKnownMatrix): +class Ops: """Basic single- and two-qubits operators.""" + I: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[1, 0], [0, 1]])) + X: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[0, 1], [1, 0]])) + Y: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[0, -1j], [1j, 0]])) + Z: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[1, 0], [0, -1]])) + S: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[1, 0], [0, 1j]])) + SDG: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[1, 0], [0, -1j]])) + H: ClassVar[npt.NDArray[np.complex128]] = utils.lock(np.asarray([[1, 1], [1, -1]]) / np.sqrt(2)) + CZ: ClassVar[npt.NDArray[np.complex128]] = utils.lock( + np.asarray( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, -1], + ], + ) + ) + CNOT: ClassVar[npt.NDArray[np.complex128]] = utils.lock( + np.asarray( + [ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 0], + ], + ) + ) + SWAP: ClassVar[npt.NDArray[np.complex128]] = utils.lock( + np.asarray( + [ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], + ) + ) + CCX: ClassVar[npt.NDArray[np.complex128]] = utils.lock( + np.asarray( + [ + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0], + ], + ) + ) + @staticmethod def rx(theta: float) -> npt.NDArray[np.complex128]: """X rotation. diff --git a/graphix/pattern.py b/graphix/pattern.py index 45cced75..652e16e6 100644 --- a/graphix/pattern.py +++ b/graphix/pattern.py @@ -15,12 +15,13 @@ import typing_extensions from graphix import command -from graphix.clifford import Clifford, Domains +from graphix.clifford import Clifford from graphix.command import Command, CommandKind from graphix.device_interface import PatternRunner +from graphix.fundamentals import Axis, Plane, Sign from graphix.gflow import find_flow, find_gflow, get_layers from graphix.graphsim.graphstate import GraphState -from graphix.pauli import Axis, PauliMeasurement, Plane, Sign +from graphix.measurements import Domains, PauliMeasurement from graphix.simulator import PatternSimulator from graphix.states import BasicStates from graphix.visualization import GraphVisualizer @@ -2130,7 +2131,7 @@ def measure_pauli(pattern, leave_input, copy=False, use_rustworkx=False): measure = graph_state.measure_z else: typing_extensions.assert_never(basis.axis) - if basis.sign == Sign.Plus: + if basis.sign == Sign.PLUS: results[pattern_cmd.node] = measure(pattern_cmd.node, choice=0) else: results[pattern_cmd.node] = 1 - measure(pattern_cmd.node, choice=1) diff --git a/graphix/pauli.py b/graphix/pauli.py index 8e179c84..11af8915 100644 --- a/graphix/pauli.py +++ b/graphix/pauli.py @@ -2,298 +2,32 @@ from __future__ import annotations -import enum -import sys -import typing -from numbers import Number -from typing import TYPE_CHECKING +import dataclasses +from typing import TYPE_CHECKING, ClassVar -import numpy as np -import numpy.typing as npt import typing_extensions -from graphix._db import CLIFFORD +from graphix.fundamentals import IXYZ, Axis, ComplexUnit, SupportsComplexCtor from graphix.ops import Ops +from graphix.states import BasicStates if TYPE_CHECKING: - from graphix.states import PlanarState - - -class IXYZ(enum.Enum): - """I, X, Y or Z.""" - - I = -1 - X = 0 - Y = 1 - Z = 2 - - -class Sign(enum.Enum): - """Sign, plus or minus.""" - - Plus = 1 - Minus = -1 - - def __str__(self) -> str: - """Return `+` or `-`.""" - if self == Sign.Plus: - return "+" - return "-" - - @staticmethod - def plus_if(b: bool) -> Sign: - """Return `+` if `b` is `True`, `-` otherwise.""" - if b: - return Sign.Plus - return Sign.Minus - - @staticmethod - def minus_if(b: bool) -> Sign: - """Return `-` if `b` is `True`, `+` otherwise.""" - if b: - return Sign.Minus - return Sign.Plus - - def __neg__(self) -> Sign: - """Swap the sign.""" - return Sign.minus_if(self == Sign.Plus) - - def __mul__(self, other: SignOrNumber) -> SignOrNumber: - """Multiply the sign with another sign or a number.""" - if isinstance(other, Sign): - return Sign.plus_if(self == other) - if isinstance(other, Number): - return self.value * other - return NotImplemented - - def __rmul__(self, other) -> Number: - """Multiply the sign with a number.""" - if isinstance(other, Number): - return self.value * other - return NotImplemented - - def __int__(self) -> int: - """Return `1` for `+` and `-1` for `-`.""" - return self.value - - def __float__(self) -> float: - """Return `1.0` for `+` and `-1.0` for `-`.""" - return float(self.value) - - def __complex__(self) -> complex: - """Return `1.0 + 0j` for `+` and `-1.0 + 0j` for `-`.""" - return complex(self.value) - - -if sys.version_info >= (3, 10): - SignOrNumber = typing.TypeVar("SignOrNumber", bound=Sign | Number) -else: - SignOrNumber = typing.TypeVar("SignOrNumber", bound=typing.Union[Sign, Number]) - - -class ComplexUnit: - """ - Complex unit: 1, -1, j, -j. - - Complex units can be multiplied with other complex units, - with Python constants 1, -1, 1j, -1j, and can be negated. - """ - - def __init__(self, sign: Sign, is_imag: bool): - self.__sign = sign - self.__is_imag = is_imag - - @property - def sign(self) -> Sign: - """Return the sign.""" - return self.__sign - - @property - def is_imag(self) -> bool: - """Return `True` if `j` or `-j`.""" - return self.__is_imag - - def __complex__(self) -> complex: - """Return the unit as complex number.""" - result: complex = complex(self.__sign) - if self.__is_imag: - result *= 1j - return result - - def __repr__(self) -> str: - """Return a string representation of the unit.""" - if self.__is_imag: - result = "1j" - else: - result = "1" - if self.__sign == Sign.Minus: - result = "-" + result - return result - - def prefix(self, s: str) -> str: - """Prefix the given string by the complex unit as coefficient, 1 leaving the string unchanged.""" - if self.__is_imag: - result = "1j*" + s - else: - result = s - if self.__sign == Sign.Minus: - result = "-" + result - return result - - def __mul__(self, other: ComplexUnit) -> ComplexUnit: - """Multiply the complex unit with another complex unit.""" - if isinstance(other, ComplexUnit): - is_imag = self.__is_imag != other.__is_imag - sign = self.__sign * other.__sign * Sign.minus_if(self.__is_imag and other.__is_imag) - return COMPLEX_UNITS[sign == Sign.Minus][is_imag] - return NotImplemented - - def __rmul__(self, other): - """Multiply the complex unit with a number.""" - if other == 1: - return self - elif other == -1: - return COMPLEX_UNITS[self.__sign == Sign.Plus][self.__is_imag] - elif other == 1j: - return COMPLEX_UNITS[self.__sign == Sign.plus_if(self.__is_imag)][not self.__is_imag] - elif other == -1j: - return COMPLEX_UNITS[self.__sign == Sign.minus_if(self.__is_imag)][not self.__is_imag] - - def __neg__(self): - """Return the opposite of the complex unit.""" - return COMPLEX_UNITS[self.__sign == Sign.Plus][self.__is_imag] - - -COMPLEX_UNITS = tuple( - tuple(ComplexUnit(sign, is_imag) for is_imag in (False, True)) for sign in (Sign.Plus, Sign.Minus) -) - - -UNIT = COMPLEX_UNITS[False][False] - - -UNITS = (UNIT, -UNIT, 1j * UNIT, -1j * UNIT) - + from collections.abc import Iterator -class Axis(enum.Enum): - """Axis: `X`, `Y` or `Z`.""" + import numpy as np + import numpy.typing as npt - X = 0 - Y = 1 - Z = 2 - - @property - def op(self) -> npt.NDArray: - """Return the single qubit operator associated to the axis.""" - if self == Axis.X: - return Ops.X - if self == Axis.Y: - return Ops.Y - if self == Axis.Z: - return Ops.Z - - typing_extensions.assert_never(self) - - -class Plane(enum.Enum): - """Plane: `XY`, `YZ` or `XZ`.""" - - XY = 0 - YZ = 1 - XZ = 2 - - @property - def axes(self) -> list[Axis]: - """Return the pair of axes that carry the plane.""" - # match self: - # case Plane.XY: - # return [Axis.X, Axis.Y] - # case Plane.YZ: - # return [Axis.Y, Axis.Z] - # case Plane.XZ: - # return [Axis.X, Axis.Z] - if self == Plane.XY: - return [Axis.X, Axis.Y] - elif self == Plane.YZ: - return [Axis.Y, Axis.Z] - elif self == Plane.XZ: - return [Axis.X, Axis.Z] - - @property - def orth(self) -> Axis: - """Return the axis orthogonal to the plane.""" - if self == Plane.XY: - return Axis.Z - if self == Plane.YZ: - return Axis.X - if self == Plane.XZ: - return Axis.Y - typing_extensions.assert_never(self) - - @property - def cos(self) -> Axis: - """Return the axis of the plane that conventionally carries the cos.""" - # match self: - # case Plane.XY: - # return Axis.X - # case Plane.YZ: - # return Axis.Z # former convention was Y - # case Plane.XZ: - # return Axis.Z # former convention was X - if self == Plane.XY: - return Axis.X - elif self == Plane.YZ: - return Axis.Z # former convention was Y - elif self == Plane.XZ: - return Axis.Z # former convention was X - - @property - def sin(self) -> Axis: - """Return the axis of the plane that conventionally carries the sin.""" - # match self: - # case Plane.XY: - # return Axis.Y - # case Plane.YZ: - # return Axis.Y # former convention was Z - # case Plane.XZ: - # return Axis.X # former convention was Z - if self == Plane.XY: - return Axis.Y - elif self == Plane.YZ: - return Axis.Y # former convention was Z - elif self == Plane.XZ: - return Axis.X # former convention was Z + from graphix.states import PlanarState - def polar(self, angle: float) -> tuple[float, float, float]: - """Return the Cartesian coordinates of the point of module 1 at the given angle, following the conventional orientation for cos and sin.""" - result = [0, 0, 0] - result[self.cos.value] = np.cos(angle) - result[self.sin.value] = np.sin(angle) - return tuple(result) - @staticmethod - def from_axes(a: Axis, b: Axis) -> Plane: - """Return the plane carried by the given axes.""" - if b.value < a.value: - a, b = b, a - # match a, b: - # case Axis.X, Axis.Y: - # return Plane.XY - # case Axis.Y, Axis.Z: - # return Plane.YZ - # case Axis.X, Axis.Z: - # return Plane.XZ - if a == Axis.X and b == Axis.Y: - return Plane.XY - elif a == Axis.Y and b == Axis.Z: - return Plane.YZ - elif a == Axis.X and b == Axis.Z: - return Plane.XZ - assert a == b - raise ValueError(f"Cannot make a plane giving the same axis {a} twice.") +class _PauliMeta(type): + def __iter__(cls) -> Iterator[Pauli]: + """Iterate over all Pauli gates, including the unit.""" + return Pauli.iterate() -class Pauli: +@dataclasses.dataclass(frozen=True) +class Pauli(metaclass=_PauliMeta): """Pauli gate: `u * {I, X, Y, Z}` where u is a complex unit. Pauli gates can be multiplied with other Pauli gates (with `@`), @@ -301,14 +35,17 @@ class Pauli: and can be negated. """ - def __init__(self, symbol: IXYZ, unit: ComplexUnit): - self.__symbol = symbol - self.__unit = unit + symbol: IXYZ = IXYZ.I + unit: ComplexUnit = ComplexUnit.ONE + I: ClassVar[Pauli] + X: ClassVar[Pauli] + Y: ClassVar[Pauli] + Z: ClassVar[Pauli] @staticmethod def from_axis(axis: Axis) -> Pauli: """Return the Pauli associated to the given axis.""" - return Pauli(IXYZ[axis.name], UNIT) + return Pauli(IXYZ[axis.name]) @property def axis(self) -> Axis: @@ -316,127 +53,119 @@ def axis(self) -> Axis: Fails if the Pauli is identity. """ - if self.__symbol == IXYZ.I: + if self.symbol == IXYZ.I: raise ValueError("I is not an axis.") - return Axis[self.__symbol.name] - - @property - def symbol(self) -> IXYZ: - """Return the symbol (without the complex unit).""" - return self.__symbol - - @property - def unit(self) -> ComplexUnit: - """Return the complex unit.""" - return self.__unit + return Axis[self.symbol.name] @property - def matrix(self) -> npt.NDArray: + def matrix(self) -> npt.NDArray[np.complex128]: """Return the matrix of the Pauli gate.""" - return complex(self.__unit) * CLIFFORD[self.__symbol.value + 1] + co = complex(self.unit) + if self.symbol == IXYZ.I: + return co * Ops.I + if self.symbol == IXYZ.X: + return co * Ops.X + if self.symbol == IXYZ.Y: + return co * Ops.Y + if self.symbol == IXYZ.Z: + return co * Ops.Z + typing_extensions.assert_never(self.symbol) - def get_eigenstate(self, eigenvalue=0) -> PlanarState: + def eigenstate(self, binary: int = 0) -> PlanarState: """Return the eigenstate of the Pauli.""" - from graphix.states import BasicStates - + if binary not in {0, 1}: + raise ValueError("b must be 0 or 1.") if self.symbol == IXYZ.X: - return BasicStates.PLUS if eigenvalue == 0 else BasicStates.MINUS + return BasicStates.PLUS if binary == 0 else BasicStates.MINUS if self.symbol == IXYZ.Y: - return BasicStates.PLUS_I if eigenvalue == 0 else BasicStates.MINUS_I + return BasicStates.PLUS_I if binary == 0 else BasicStates.MINUS_I if self.symbol == IXYZ.Z: - return BasicStates.ZERO if eigenvalue == 0 else BasicStates.ONE + return BasicStates.ZERO if binary == 0 else BasicStates.ONE # Any state is eigenstate of the identity if self.symbol == IXYZ.I: return BasicStates.PLUS typing_extensions.assert_never(self.symbol) + def _repr_impl(self, prefix: str | None) -> str: + sym = self.symbol.name + if prefix is not None: + sym = f"{prefix}.{sym}" + if self.unit == ComplexUnit.ONE: + return sym + if self.unit == ComplexUnit.MINUS_ONE: + return f"-{sym}" + if self.unit == ComplexUnit.J: + return f"1j * {sym}" + if self.unit == ComplexUnit.MINUS_J: + return f"-1j * {sym}" + typing_extensions.assert_never(self.unit) + def __repr__(self) -> str: - """Return a fully qualified string representation of the Pauli.""" - return self.__unit.prefix(f"graphix.pauli.{self.__symbol.name}") + """Return a string representation of the Pauli.""" + return self._repr_impl(self.__class__.__name__) def __str__(self) -> str: - """Return a string representation of the Pauli (without module prefix).""" - return self.__unit.prefix(self.__symbol.name) + """Return a simplified string representation of the Pauli.""" + return self._repr_impl(None) + + @staticmethod + def _matmul_impl(lhs: IXYZ, rhs: IXYZ) -> Pauli: + if lhs == IXYZ.I: + return Pauli(rhs) + if rhs == IXYZ.I: + return Pauli(lhs) + if lhs == rhs: + return Pauli() + lr = (lhs, rhs) + if lr == (IXYZ.X, IXYZ.Y): + return Pauli(IXYZ.Z, ComplexUnit.J) + if lr == (IXYZ.Y, IXYZ.X): + return Pauli(IXYZ.Z, ComplexUnit.MINUS_J) + if lr == (IXYZ.Y, IXYZ.Z): + return Pauli(IXYZ.X, ComplexUnit.J) + if lr == (IXYZ.Z, IXYZ.Y): + return Pauli(IXYZ.X, ComplexUnit.MINUS_J) + if lr == (IXYZ.Z, IXYZ.X): + return Pauli(IXYZ.Y, ComplexUnit.J) + if lr == (IXYZ.X, IXYZ.Z): + return Pauli(IXYZ.Y, ComplexUnit.MINUS_J) + raise RuntimeError("Unreachable.") # pragma: no cover def __matmul__(self, other: Pauli) -> Pauli: """Return the product of two Paulis.""" if isinstance(other, Pauli): - if self.__symbol == IXYZ.I: - symbol = other.__symbol - unit = 1 - elif other.__symbol == IXYZ.I: - symbol = self.__symbol - unit = 1 - elif self.__symbol == other.__symbol: - symbol = IXYZ.I - unit = 1 - elif (self.__symbol.value + 1) % 3 == other.__symbol.value: - symbol = IXYZ((self.__symbol.value + 2) % 3) - unit = 1j - else: - symbol = IXYZ((self.__symbol.value + 1) % 3) - unit = -1j - return get(symbol, unit * self.__unit * other.__unit) + return self._matmul_impl(self.symbol, other.symbol) * (self.unit * other.unit) return NotImplemented - def __rmul__(self, other: ComplexUnit) -> Pauli: + def __mul__(self, other: ComplexUnit | SupportsComplexCtor) -> Pauli: """Return the product of two Paulis.""" - if isinstance(other, ComplexUnit): - return get(self.__symbol, other * self.__unit) + if u := ComplexUnit.try_from(other): + return dataclasses.replace(self, unit=self.unit * u) return NotImplemented + def __rmul__(self, other: ComplexUnit | SupportsComplexCtor) -> Pauli: + """Return the product of two Paulis.""" + return self.__mul__(other) + def __neg__(self) -> Pauli: """Return the opposite.""" - return get(self.__symbol, -self.__unit) - - -TABLE = tuple( - tuple(tuple(Pauli(symbol, COMPLEX_UNITS[sign][is_imag]) for is_imag in (False, True)) for sign in (False, True)) - for symbol in (IXYZ.I, IXYZ.X, IXYZ.Y, IXYZ.Z) -) - - -LIST = tuple(pauli for sign_im_list in TABLE for im_list in sign_im_list for pauli in im_list) - - -def get(symbol: IXYZ, unit: ComplexUnit) -> Pauli: - """Return the Pauli gate with given symbol and unit.""" - return TABLE[symbol.value + 1][unit.sign == Sign.Minus][unit.is_imag] - - -# TODO: Include in Pauli namespace -I = get(IXYZ.I, UNIT) -X = get(IXYZ.X, UNIT) -Y = get(IXYZ.Y, UNIT) -Z = get(IXYZ.Z, UNIT) - - -def parse(name: str) -> Pauli: - """Return the Pauli gate with the given name (limited to "I", "X", "Y" and "Z").""" - return get(IXYZ[name], UNIT) - - -def is_int(value: Number) -> bool: - """Return `True` if `value` is an integer, `False` otherwise.""" - return value == int(value) + return dataclasses.replace(self, unit=-self.unit) + @staticmethod + def iterate(symbol_only: bool = False) -> Iterator[Pauli]: + """Iterate over all Pauli gates. -class PauliMeasurement(typing.NamedTuple): - """Pauli measurement.""" + Parameters + ---------- + symbol_only (bool, optional): Exclude the unit in the iteration. Defaults to False. + """ + us = (ComplexUnit.ONE,) if symbol_only else tuple(ComplexUnit) + for symbol in IXYZ: + for unit in us: + yield Pauli(symbol, unit) - axis: Axis - sign: Sign - @staticmethod - def try_from(plane: Plane, angle: float) -> PauliMeasurement | None: - """Return the Pauli measurement description if a given measure is Pauli.""" - angle_double = 2 * angle - if not is_int(angle_double): - return None - angle_double_mod_4 = int(angle_double) % 4 - if angle_double_mod_4 % 2 == 0: - axis = plane.cos - else: - axis = plane.sin - sign = Sign.minus_if(angle_double_mod_4 >= 2) - return PauliMeasurement(axis, sign) +Pauli.I = Pauli(IXYZ.I) +Pauli.X = Pauli(IXYZ.X) +Pauli.Y = Pauli(IXYZ.Y) +Pauli.Z = Pauli(IXYZ.Z) diff --git a/graphix/pyzx.py b/graphix/pyzx.py index 281640ea..2925780f 100644 --- a/graphix/pyzx.py +++ b/graphix/pyzx.py @@ -15,8 +15,9 @@ from pyzx.graph import Graph from pyzx.utils import EdgeType, FractionLike, VertexType -from graphix.opengraph import Measurement, OpenGraph -from graphix.pauli import Plane +from graphix.fundamentals import Plane +from graphix.measurements import Measurement +from graphix.opengraph import OpenGraph if TYPE_CHECKING: from pyzx.graph.base import BaseGraph diff --git a/graphix/sim/base_backend.py b/graphix/sim/base_backend.py index 544c8f3c..2b7af5f4 100644 --- a/graphix/sim/base_backend.py +++ b/graphix/sim/base_backend.py @@ -2,7 +2,6 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING import numpy as np @@ -18,15 +17,8 @@ from numpy.random import Generator - from graphix.pauli import Plane - - -@dataclass -class MeasurementDescription: - """An MBQC measurement.""" - - plane: Plane - angle: float + from graphix.fundamentals import Plane + from graphix.measurements import Measurement class NodeIndex: @@ -195,18 +187,16 @@ def entangle_nodes(self, edge: tuple[int, int]) -> None: control = self.node_index.index(edge[1]) self.state.entangle((target, control)) - def measure(self, node: int, measurement_description: MeasurementDescription) -> bool: + def measure(self, node: int, measurement: Measurement) -> bool: """Perform measurement of a node and trace out the qubit. Parameters ---------- node: int - measurement_description: MeasurementDescription + measurement: Measurement """ loc = self.node_index.index(node) - result = perform_measure( - loc, measurement_description.plane, measurement_description.angle, self.state, self.__rng, self.__pr_calc - ) + result = perform_measure(loc, measurement.plane, measurement.angle, self.state, self.__rng, self.__pr_calc) self.node_index.remove(node) self.state.remove_qubit(loc) return result diff --git a/graphix/sim/statevec.py b/graphix/sim/statevec.py index 6548fa29..bfc5ec6a 100644 --- a/graphix/sim/statevec.py +++ b/graphix/sim/statevec.py @@ -11,7 +11,7 @@ import numpy as np import numpy.typing as npt -from graphix import states, type_utils +from graphix import states, utils from graphix.sim.base_backend import Backend, State from graphix.states import BasicStates @@ -111,7 +111,7 @@ def __init__( else: if isinstance(input_list[0], states.State): - type_utils.check_list_elements(input_list, states.State) + utils.check_list_elements(input_list, states.State) if nqubit is None: nqubit = len(input_list) elif nqubit != len(input_list): @@ -121,7 +121,7 @@ def __init__( # reshape self.psi = tmp_psi.reshape((2,) * nqubit) elif isinstance(input_list[0], numbers.Number): - type_utils.check_list_elements(input_list, numbers.Number) + utils.check_list_elements(input_list, numbers.Number) if nqubit is None: length = len(input_list) if length & (length - 1): diff --git a/graphix/sim/tensornet.py b/graphix/sim/tensornet.py index 90283a1b..8417bf09 100644 --- a/graphix/sim/tensornet.py +++ b/graphix/sim/tensornet.py @@ -13,13 +13,14 @@ from graphix import command from graphix.ops import Ops from graphix.rng import ensure_rng -from graphix.sim.base_backend import Backend, MeasurementDescription, State +from graphix.sim.base_backend import Backend, State from graphix.states import BasicStates, PlanarState if TYPE_CHECKING: from numpy.random import Generator from graphix.clifford import Clifford + from graphix.measurements import Measurement from graphix.simulator import MeasureMethod @@ -143,7 +144,7 @@ def entangle_nodes(self, edge) -> None: elif self.graph_prep == "opt": pass - def measure(self, node: int, measurement_description: MeasurementDescription) -> tuple[Backend, int]: + def measure(self, node: int, measurement: Measurement) -> tuple[Backend, int]: """Perform measurement of the node. In the context of tensornetwork, performing measurement equals to @@ -153,7 +154,7 @@ def measure(self, node: int, measurement_description: MeasurementDescription) -> ---------- node : int index of the node to measure - measurement_description : MeasurementDescription + measurement : Measurement measure plane and angle """ if node in self._isolated_nodes: @@ -168,9 +169,9 @@ def measure(self, node: int, measurement_description: MeasurementDescription) -> result = self.__rng.choice([0, 1]) self.results[node] = result buffer = 2**0.5 - vec = PlanarState(measurement_description.plane, measurement_description.angle).get_statevector() + vec = PlanarState(measurement.plane, measurement.angle).get_statevector() if result: - vec = measurement_description.plane.orth.op @ vec + vec = measurement.plane.orth.matrix @ vec proj_vec = vec * buffer self.state.measure_single(node, basis=proj_vec) return result diff --git a/graphix/simulator.py b/graphix/simulator.py index 412028c5..941c676f 100644 --- a/graphix/simulator.py +++ b/graphix/simulator.py @@ -14,7 +14,8 @@ from graphix.clifford import Clifford from graphix.command import BaseM, CommandKind, M, MeasureUpdate -from graphix.sim.base_backend import Backend, MeasurementDescription +from graphix.measurements import Measurement +from graphix.sim.base_backend import Backend from graphix.sim.density_matrix import DensityMatrixBackend from graphix.sim.statevec import StatevectorBackend from graphix.sim.tensornet import TensorNetworkBackend @@ -41,7 +42,7 @@ def measure(self, backend: Backend, cmd, noise_model=None) -> None: self.set_measure_result(cmd.node, result) @abc.abstractmethod - def get_measurement_description(self, cmd: BaseM) -> MeasurementDescription: + def get_measurement_description(self, cmd: BaseM) -> Measurement: """Return the description of the measurement performed by a given measure command (possibly blind).""" ... @@ -64,7 +65,7 @@ def __init__(self, results=None): results = dict() self.results = results - def get_measurement_description(self, cmd: BaseM) -> MeasurementDescription: + def get_measurement_description(self, cmd: BaseM) -> Measurement: """Return the description of the measurement performed by a given measure command (cannot be blind in the case of DefaultMeasureMethod).""" assert isinstance(cmd, M) angle = cmd.angle * np.pi @@ -73,7 +74,7 @@ def get_measurement_description(self, cmd: BaseM) -> MeasurementDescription: t_signal = sum(self.results[j] for j in cmd.t_domain) measure_update = MeasureUpdate.compute(cmd.plane, s_signal % 2 == 1, t_signal % 2 == 1, Clifford.I) angle = angle * measure_update.coeff + measure_update.add_term - return MeasurementDescription(measure_update.new_plane, angle) + return Measurement(angle, measure_update.new_plane) def get_measure_result(self, node: int) -> bool: """Return the result of a previous measurement.""" diff --git a/graphix/states.py b/graphix/states.py index 7d2480cc..af639a74 100644 --- a/graphix/states.py +++ b/graphix/states.py @@ -11,7 +11,7 @@ import pydantic.dataclasses import typing_extensions -from graphix.pauli import Plane +from graphix.fundamentals import Plane # generic class State for all States @@ -80,12 +80,12 @@ def get_statevector(self) -> npt.NDArray[np.complex128]: class BasicStates: """Basic states.""" - ZERO: ClassVar = PlanarState(Plane.XZ, 0) - ONE: ClassVar = PlanarState(Plane.XZ, np.pi) - PLUS: ClassVar = PlanarState(Plane.XY, 0) - MINUS: ClassVar = PlanarState(Plane.XY, np.pi) - PLUS_I: ClassVar = PlanarState(Plane.XY, np.pi / 2) - MINUS_I: ClassVar = PlanarState(Plane.XY, -np.pi / 2) + ZERO: ClassVar[PlanarState] = PlanarState(Plane.XZ, 0) + ONE: ClassVar[PlanarState] = PlanarState(Plane.XZ, np.pi) + PLUS: ClassVar[PlanarState] = PlanarState(Plane.XY, 0) + MINUS: ClassVar[PlanarState] = PlanarState(Plane.XY, np.pi) + PLUS_I: ClassVar[PlanarState] = PlanarState(Plane.XY, np.pi / 2) + MINUS_I: ClassVar[PlanarState] = PlanarState(Plane.XY, -np.pi / 2) # remove that in the end # need in TN backend - VEC: ClassVar = [PLUS, MINUS, ZERO, ONE, PLUS_I, MINUS_I] + VEC: ClassVar[list[PlanarState]] = [PLUS, MINUS, ZERO, ONE, PLUS_I, MINUS_I] diff --git a/graphix/transpiler.py b/graphix/transpiler.py index 5abfd8ce..e4d1e2d9 100644 --- a/graphix/transpiler.py +++ b/graphix/transpiler.py @@ -16,9 +16,9 @@ from graphix import command, instruction from graphix.clifford import Clifford from graphix.command import CommandKind, E, M, N, X, Z +from graphix.fundamentals import Plane from graphix.ops import Ops from graphix.pattern import Pattern -from graphix.pauli import Plane from graphix.sim import base_backend from graphix.sim.statevec import Data, Statevec diff --git a/graphix/type_utils.py b/graphix/utils.py similarity index 59% rename from graphix/type_utils.py rename to graphix/utils.py index e1fbc0cc..5356bcaf 100644 --- a/graphix/type_utils.py +++ b/graphix/utils.py @@ -1,10 +1,13 @@ -"""Type utilities.""" +"""Utilities.""" from __future__ import annotations import sys import typing -from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Literal, SupportsInt, TypeVar + +import numpy as np +import numpy.typing as npt if TYPE_CHECKING: from collections.abc import Iterable @@ -41,3 +44,31 @@ def check_kind(cls: type, scope: dict[str, Any]) -> None: if typing.get_origin(ann) is not Literal: msg = "Tag attribute must be a literal." raise TypeError(msg) + + +def is_integer(value: SupportsInt) -> bool: + """Return `True` if `value` is an integer, `False` otherwise.""" + return value == int(value) + + +G = TypeVar("G", bound=np.generic) + + +@typing.overload +def lock(data: npt.NDArray[Any]) -> npt.NDArray[np.complex128]: ... + + +@typing.overload +def lock(data: npt.NDArray[Any], dtype: type[G]) -> npt.NDArray[G]: ... + + +def lock(data: npt.NDArray[Any], dtype: type = np.complex128) -> npt.NDArray[Any]: + """Create a true immutable view. + + data must not have aliasing references, otherwise users can still turn on writeable flag of m. + """ + m: npt.NDArray[Any] = data.astype(dtype) + m.flags.writeable = False + v = m.view() + assert not v.flags.writeable + return v diff --git a/graphix/visualization.py b/graphix/visualization.py index 79874396..5bbf3b33 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -11,7 +11,7 @@ from matplotlib import pyplot as plt from graphix import gflow -from graphix.pauli import Plane +from graphix.fundamentals import Plane if TYPE_CHECKING: # MEMO: Potential circular import diff --git a/pyproject.toml b/pyproject.toml index 397e99c7..2f77d0bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,22 +101,27 @@ filterwarnings = ["ignore:Couldn't import `kahypar`"] [tool.mypy] # Keep in sync with pyright files = [ - "**/__init__.py", "graphix/channels.py", "graphix/clifford.py", "graphix/command.py", + "graphix/fundamentals.py", "graphix/instruction.py", "graphix/linalg_validations.py", + "graphix/measurements.py", "graphix/ops.py", + "graphix/pauli.py", "graphix/pyzx.py", "graphix/rng.py", "graphix/states.py", - "graphix/type_utils.py", + "graphix/utils.py", + "graphix/**/__init__.py", "graphix/_db.py", "noxfile.py", "tests/conftest.py", "tests/test_clifford.py", + "tests/test_command.py", "tests/test_db.py", + "tests/test_fundamentals.py", "tests/test_kraus.py", "tests/test_pauli.py", "tests/test_pyzx.py", @@ -128,22 +133,27 @@ strict = true [tool.pyright] # Keep in sync with mypy include = [ - "**/__init__.py", "graphix/channels.py", "graphix/clifford.py", "graphix/command.py", + "graphix/fundamentals.py", "graphix/instruction.py", "graphix/linalg_validations.py", + "graphix/measurements.py", "graphix/ops.py", + "graphix/pauli.py", "graphix/pyzx.py", "graphix/rng.py", "graphix/states.py", - "graphix/type_utils.py", + "graphix/utils.py", + "graphix/**/__init__.py", "graphix/_db.py", "noxfile.py", "tests/conftest.py", "tests/test_clifford.py", + "tests/test_command.py", "tests/test_db.py", + "tests/test_fundamentals.py", "tests/test_kraus.py", "tests/test_pauli.py", "tests/test_pyzx.py", diff --git a/tests/test_clifford.py b/tests/test_clifford.py index de57d0ff..d89c52ea 100644 --- a/tests/test_clifford.py +++ b/tests/test_clifford.py @@ -1,15 +1,22 @@ from __future__ import annotations +import cmath import functools import itertools +import math import operator -from typing import Final +import re +from typing import TYPE_CHECKING, Final import numpy as np import pytest from graphix.clifford import Clifford -from graphix.pauli import IXYZ, ComplexUnit, Pauli, Sign +from graphix.fundamentals import IXYZ, ComplexUnit, Sign +from graphix.pauli import Pauli + +if TYPE_CHECKING: + from numpy.random import Generator _QASM3_DB: Final = { "id": Clifford.I, @@ -38,7 +45,10 @@ def test_iteration(self) -> None: @pytest.mark.parametrize("c", Clifford) def test_repr(self, c: Clifford) -> None: - for term in repr(c).split(" @ "): + rep: str = repr(c) + m = re.match(r"\((.*)\)", rep) + rep = m.group(1) if m is not None else rep + for term in rep.split(" @ "): assert term in [ "Clifford.I", "Clifford.H", @@ -54,10 +64,10 @@ def test_repr(self, c: Clifford) -> None: Pauli(sym, u) for sym in IXYZ for u in ( - ComplexUnit(Sign.Plus, False), - ComplexUnit(Sign.Minus, False), - ComplexUnit(Sign.Plus, True), - ComplexUnit(Sign.Minus, True), + ComplexUnit.from_properties(sign=Sign.PLUS, is_imag=False), + ComplexUnit.from_properties(sign=Sign.MINUS, is_imag=False), + ComplexUnit.from_properties(sign=Sign.PLUS, is_imag=True), + ComplexUnit.from_properties(sign=Sign.MINUS, is_imag=True), ) ), ), @@ -75,3 +85,12 @@ def test_measure(self, c: Clifford, p: Pauli) -> None: def test_qasm3(self, c: Clifford) -> None: cmul: Clifford = functools.reduce(operator.matmul, (_QASM3_DB[term] for term in reversed(c.qasm3))) assert cmul == c + + @pytest.mark.parametrize("c", Clifford) + def test_try_from_matrix(self, fx_rng: Generator, c: Clifford) -> None: + co = cmath.exp(2j * math.pi * fx_rng.uniform()) + assert Clifford.try_from_matrix(co * c.matrix) == c + + def test_try_from_matrix_ng(self, fx_rng: Generator) -> None: + assert Clifford.try_from_matrix(np.zeros((2, 3))) is None + assert Clifford.try_from_matrix(fx_rng.normal(size=(2, 2))) is None diff --git a/tests/test_command.py b/tests/test_command.py new file mode 100644 index 00000000..f3993d5e --- /dev/null +++ b/tests/test_command.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import itertools +import math + +import numpy as np +import pytest + +from graphix.clifford import Clifford +from graphix.command import MeasureUpdate +from graphix.fundamentals import Plane + + +@pytest.mark.parametrize( + ("plane", "s", "t", "clifford", "angle", "choice"), + itertools.product( + Plane, + (False, True), + (False, True), + Clifford, + (0, math.pi), + (False, True), + ), +) +def test_measure_update( + plane: Plane, + s: bool, + t: bool, + clifford: Clifford, + angle: float, + choice: bool, +) -> None: + measure_update = MeasureUpdate.compute(plane, s, t, clifford) + new_angle = angle * measure_update.coeff + measure_update.add_term + vec = measure_update.new_plane.polar(new_angle) + op_mat = np.eye(2, dtype=np.complex128) / 2 + for i in range(3): + op_mat += (-1) ** (choice) * vec[i] * Clifford(i + 1).matrix / 2 + + if s: + clifford = Clifford.X @ clifford + if t: + clifford = Clifford.Z @ clifford + vec = plane.polar(angle) + op_mat_ref = np.eye(2, dtype=np.complex128) / 2 + for i in range(3): + op_mat_ref += (-1) ** (choice) * vec[i] * Clifford(i + 1).matrix / 2 + clifford_mat = clifford.matrix + op_mat_ref = clifford_mat.conj().T @ op_mat_ref @ clifford_mat + + assert np.allclose(op_mat, op_mat_ref) or np.allclose(op_mat, -op_mat_ref) diff --git a/tests/test_db.py b/tests/test_db.py index 3c496dde..9ea3df69 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -3,7 +3,6 @@ import itertools import numpy as np -import numpy.typing as npt import pytest from graphix._db import ( @@ -12,90 +11,36 @@ CLIFFORD_HSZ_DECOMPOSITION, CLIFFORD_MEASURE, CLIFFORD_MUL, - _CliffordMeasure, ) +from graphix.clifford import Clifford -class TestClifford: - @staticmethod - def classify_pauli(arr: npt.NDArray[np.complex128]) -> _CliffordMeasure: - """Compare the gate arr with Pauli gates and return the tuple of (Pauli string, sign). - - Parameters - ---------- - arr: np.array - 2x2 matrix. - - Returns - ------- - ind : _CliffordMeasure - """ - if np.allclose(CLIFFORD[1], arr): - return _CliffordMeasure("X", +1) - if np.allclose(-1 * CLIFFORD[1], arr): - return _CliffordMeasure("X", -1) - if np.allclose(CLIFFORD[2], arr): - return _CliffordMeasure("Y", +1) - if np.allclose(-1 * CLIFFORD[2], arr): - return _CliffordMeasure("Y", -1) - if np.allclose(CLIFFORD[3], arr): - return _CliffordMeasure("Z", +1) - if np.allclose(-1 * CLIFFORD[3], arr): - return _CliffordMeasure("Z", -1) - msg = "No Pauli found" - raise ValueError(msg) - - @staticmethod - def clifford_index(g: npt.NDArray[np.complex128]) -> int: - """Return the index of Clifford for a given 2x2 matrix. - - Compare the gate g with all Clifford gates (up to global phase) and return the matching index. - - Parameters - ---------- - g : 2x2 numpy array. - - Returns - ------- - i : index of Clifford gate - """ - for i in range(24): - ci = CLIFFORD[i] - # normalise global phase - norm = g[0, 1] / ci[0, 1] if ci[0, 0] == 0 else g[0, 0] / ci[0, 0] - # compare - if np.allclose(ci * norm, g): - return i - msg = "No Clifford found" - raise ValueError(msg) - +class TestCliffordDB: @pytest.mark.parametrize(("i", "j"), itertools.product(range(24), range(3))) def test_measure(self, i: int, j: int) -> None: - conj = CLIFFORD[i].conjugate().T pauli = CLIFFORD[j + 1] - arr = np.matmul(np.matmul(conj, pauli), CLIFFORD[i]) - res = self.classify_pauli(arr) - assert res == CLIFFORD_MEASURE[i][j] + arr = CLIFFORD[i].conjugate().T @ pauli @ CLIFFORD[i] + sym, sgn = CLIFFORD_MEASURE[i][j] + arr_ = complex(sgn) * sym.matrix + assert np.allclose(arr, arr_) @pytest.mark.parametrize(("i", "j"), itertools.product(range(24), range(24))) def test_multiplication(self, i: int, j: int) -> None: - arr = np.matmul(CLIFFORD[i], CLIFFORD[j]) - assert CLIFFORD_MUL[i][j] == self.clifford_index(arr) + op = CLIFFORD[i] @ CLIFFORD[j] + assert Clifford.try_from_matrix(op) == Clifford(CLIFFORD_MUL[i][j]) @pytest.mark.parametrize("i", range(24)) def test_conjugation(self, i: int) -> None: - arr = CLIFFORD[i].conjugate().T - assert CLIFFORD_CONJ[i] == self.clifford_index(arr) + op = CLIFFORD[i].conjugate().T + assert Clifford.try_from_matrix(op) == Clifford(CLIFFORD_CONJ[i]) - @pytest.mark.parametrize("i", range(1, 24)) + @pytest.mark.parametrize("i", range(24)) def test_decomposition(self, i: int) -> None: op = np.eye(2, dtype=np.complex128) for j in CLIFFORD_HSZ_DECOMPOSITION[i]: op = op @ CLIFFORD[j] - assert i == self.clifford_index(op) - + assert Clifford.try_from_matrix(op) == Clifford(i) -class TestDB: @pytest.mark.parametrize("i", range(24)) def test_safety(self, i: int) -> None: with pytest.raises(TypeError): diff --git a/tests/test_density_matrix.py b/tests/test_density_matrix.py index b240a07f..7bf2d9a0 100644 --- a/tests/test_density_matrix.py +++ b/tests/test_density_matrix.py @@ -11,8 +11,8 @@ import graphix.random_objects as randobj from graphix.channels import KrausChannel, dephasing_channel, depolarising_channel +from graphix.fundamentals import Plane from graphix.ops import Ops -from graphix.pauli import Plane from graphix.sim.density_matrix import DensityMatrix, DensityMatrixBackend from graphix.sim.statevec import CNOT_TENSOR, CZ_TENSOR, SWAP_TENSOR, Statevec, StatevectorBackend from graphix.simulator import DefaultMeasureMethod diff --git a/tests/test_fundamentals.py b/tests/test_fundamentals.py new file mode 100644 index 00000000..a127a83e --- /dev/null +++ b/tests/test_fundamentals.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import itertools +import math + +import pytest + +from graphix.fundamentals import Axis, ComplexUnit, Plane, Sign + + +class TestSign: + def test_str(self) -> None: + assert str(Sign.PLUS) == "+" + assert str(Sign.MINUS) == "-" + + def test_plus_if(self) -> None: + assert Sign.plus_if(True) == Sign.PLUS + assert Sign.plus_if(False) == Sign.MINUS + + def test_minus_if(self) -> None: + assert Sign.minus_if(True) == Sign.MINUS + assert Sign.minus_if(False) == Sign.PLUS + + def test_neg(self) -> None: + assert -Sign.PLUS == Sign.MINUS + assert -Sign.MINUS == Sign.PLUS + + def test_mul_sign(self) -> None: + assert Sign.PLUS * Sign.PLUS == Sign.PLUS + assert Sign.PLUS * Sign.MINUS == Sign.MINUS + assert Sign.MINUS * Sign.PLUS == Sign.MINUS + assert Sign.MINUS * Sign.MINUS == Sign.PLUS + + def test_mul_int(self) -> None: + left = Sign.PLUS * 1 + assert isinstance(left, int) + assert left == int(Sign.PLUS) + right = 1 * Sign.PLUS + assert isinstance(right, int) + assert right == int(Sign.PLUS) + + left = Sign.MINUS * 1 + assert isinstance(left, int) + assert left == int(Sign.MINUS) + right = 1 * Sign.MINUS + assert isinstance(right, int) + assert right == int(Sign.MINUS) + + def test_mul_float(self) -> None: + left = Sign.PLUS * 1.0 + assert isinstance(left, float) + assert left == float(Sign.PLUS) + right = 1.0 * Sign.PLUS + assert isinstance(right, float) + assert right == float(Sign.PLUS) + + left = Sign.MINUS * 1.0 + assert isinstance(left, float) + assert left == float(Sign.MINUS) + right = 1.0 * Sign.MINUS + assert isinstance(right, float) + assert right == float(Sign.MINUS) + + def test_mul_complex(self) -> None: + left = Sign.PLUS * complex(1) + assert isinstance(left, complex) + assert left == complex(Sign.PLUS) + right = complex(1) * Sign.PLUS + assert isinstance(right, complex) + assert right == complex(Sign.PLUS) + + left = Sign.MINUS * complex(1) + assert isinstance(left, complex) + assert left == complex(Sign.MINUS) + right = complex(1) * Sign.MINUS + assert isinstance(right, complex) + assert right == complex(Sign.MINUS) + + def test_int(self) -> None: + # Necessary to justify `type: ignore` + assert isinstance(int(Sign.PLUS), int) + assert isinstance(int(Sign.MINUS), int) + + +class TestComplexUnit: + def test_try_from(self) -> None: + assert ComplexUnit.try_from(ComplexUnit.ONE) == ComplexUnit.ONE + assert ComplexUnit.try_from(1) == ComplexUnit.ONE + assert ComplexUnit.try_from(1.0) == ComplexUnit.ONE + assert ComplexUnit.try_from(1.0 + 0.0j) == ComplexUnit.ONE + assert ComplexUnit.try_from(3) is None + + def test_from_properties(self) -> None: + assert ComplexUnit.from_properties() == ComplexUnit.ONE + assert ComplexUnit.from_properties(is_imag=True) == ComplexUnit.J + assert ComplexUnit.from_properties(sign=Sign.MINUS) == ComplexUnit.MINUS_ONE + assert ComplexUnit.from_properties(sign=Sign.MINUS, is_imag=True) == ComplexUnit.MINUS_J + + @pytest.mark.parametrize(("sign", "is_imag"), itertools.product([Sign.PLUS, Sign.MINUS], [True, False])) + def test_properties(self, sign: Sign, is_imag: bool) -> None: + assert ComplexUnit.from_properties(sign=sign, is_imag=is_imag).sign == sign + assert ComplexUnit.from_properties(sign=sign, is_imag=is_imag).is_imag == is_imag + + def test_complex(self) -> None: + assert complex(ComplexUnit.ONE) == 1 + assert complex(ComplexUnit.J) == 1j + assert complex(ComplexUnit.MINUS_ONE) == -1 + assert complex(ComplexUnit.MINUS_J) == -1j + + def test_str(self) -> None: + assert str(ComplexUnit.ONE) == "1" + assert str(ComplexUnit.J) == "1j" + assert str(ComplexUnit.MINUS_ONE) == "-1" + assert str(ComplexUnit.MINUS_J) == "-1j" + + @pytest.mark.parametrize(("lhs", "rhs"), itertools.product(ComplexUnit, ComplexUnit)) + def test_mul_self(self, lhs: ComplexUnit, rhs: ComplexUnit) -> None: + assert complex(lhs * rhs) == complex(lhs) * complex(rhs) + + def test_mul_number(self) -> None: + assert ComplexUnit.ONE * 1 == ComplexUnit.ONE + assert 1 * ComplexUnit.ONE == ComplexUnit.ONE + assert ComplexUnit.ONE * 1.0 == ComplexUnit.ONE + assert 1.0 * ComplexUnit.ONE == ComplexUnit.ONE + assert ComplexUnit.ONE * complex(1) == ComplexUnit.ONE + assert complex(1) * ComplexUnit.ONE == ComplexUnit.ONE + + def test_neg(self) -> None: + assert -ComplexUnit.ONE == ComplexUnit.MINUS_ONE + assert -ComplexUnit.J == ComplexUnit.MINUS_J + assert -ComplexUnit.MINUS_ONE == ComplexUnit.ONE + assert -ComplexUnit.MINUS_J == ComplexUnit.J + + +_PLANE_INDEX = {Axis.X: 0, Axis.Y: 1, Axis.Z: 2} + + +class TestPlane: + @pytest.mark.parametrize("p", Plane) + def test_polar_consistency(self, p: Plane) -> None: + icos = _PLANE_INDEX[p.cos] + isin = _PLANE_INDEX[p.sin] + irest = 3 - icos - isin + po = p.polar(1) + assert po[icos] == pytest.approx(math.cos(1)) + assert po[isin] == pytest.approx(math.sin(1)) + assert po[irest] == 0 + + def test_from_axes(self) -> None: + assert Plane.from_axes(Axis.X, Axis.Y) == Plane.XY + assert Plane.from_axes(Axis.Y, Axis.Z) == Plane.YZ + assert Plane.from_axes(Axis.X, Axis.Z) == Plane.XZ + assert Plane.from_axes(Axis.Y, Axis.X) == Plane.XY + assert Plane.from_axes(Axis.Z, Axis.Y) == Plane.YZ + assert Plane.from_axes(Axis.Z, Axis.X) == Plane.XZ + + def test_from_axes_ng(self) -> None: + with pytest.raises(ValueError): + Plane.from_axes(Axis.X, Axis.X) + with pytest.raises(ValueError): + Plane.from_axes(Axis.Y, Axis.Y) + with pytest.raises(ValueError): + Plane.from_axes(Axis.Z, Axis.Z) diff --git a/tests/test_generator.py b/tests/test_generator.py index 1f5229cd..f03077cf 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -6,8 +6,8 @@ import numpy as np import pytest +from graphix.fundamentals import Plane from graphix.generator import generate_from_graph -from graphix.pauli import Plane from graphix.random_objects import rand_gate if TYPE_CHECKING: diff --git a/tests/test_gflow.py b/tests/test_gflow.py index acef5b4a..e9ae212b 100644 --- a/tests/test_gflow.py +++ b/tests/test_gflow.py @@ -9,6 +9,7 @@ from numpy.random import PCG64, Generator from graphix import command +from graphix.fundamentals import Plane from graphix.gflow import ( find_flow, find_gflow, @@ -19,7 +20,6 @@ verify_pauliflow, ) from graphix.pattern import Pattern -from graphix.pauli import Plane from graphix.random_objects import rand_circuit if TYPE_CHECKING: diff --git a/tests/test_graphsim.py b/tests/test_graphsim.py index da451506..4abd8a3f 100644 --- a/tests/test_graphsim.py +++ b/tests/test_graphsim.py @@ -10,10 +10,10 @@ from networkx.utils import graphs_equal from graphix.clifford import Clifford +from graphix.fundamentals import Plane from graphix.graphsim.graphstate import GraphState from graphix.graphsim.utils import convert_rustworkx_to_networkx, is_graphs_equal from graphix.ops import Ops -from graphix.pauli import Plane from graphix.sim.statevec import Statevec with contextlib.suppress(ModuleNotFoundError): diff --git a/tests/test_opengraph.py b/tests/test_opengraph.py index 78d6ab74..6e68e82c 100644 --- a/tests/test_opengraph.py +++ b/tests/test_opengraph.py @@ -2,8 +2,9 @@ import networkx as nx -from graphix.opengraph import Measurement, OpenGraph -from graphix.pauli import Plane +from graphix.fundamentals import Plane +from graphix.measurements import Measurement +from graphix.opengraph import OpenGraph # Tests whether an open graph can be converted to and from a pattern and be diff --git a/tests/test_pattern.py b/tests/test_pattern.py index 6679abcb..8ab600db 100644 --- a/tests/test_pattern.py +++ b/tests/test_pattern.py @@ -12,8 +12,9 @@ from graphix.clifford import Clifford from graphix.command import C, CommandKind, E, M, N, X, Z +from graphix.fundamentals import Plane +from graphix.measurements import PauliMeasurement from graphix.pattern import CommandNode, Pattern, shift_outcomes -from graphix.pauli import PauliMeasurement, Plane from graphix.random_objects import rand_circuit, rand_gate from graphix.sim.density_matrix import DensityMatrix from graphix.simulator import PatternSimulator diff --git a/tests/test_pauli.py b/tests/test_pauli.py index 860b0583..13b22fc9 100644 --- a/tests/test_pauli.py +++ b/tests/test_pauli.py @@ -5,69 +5,102 @@ import numpy as np import pytest -from graphix import pauli -from graphix.clifford import Clifford -from graphix.command import MeasureUpdate -from graphix.pauli import UNITS, ComplexUnit, Pauli, Plane +from graphix.fundamentals import Axis, ComplexUnit, Sign +from graphix.pauli import Pauli class TestPauli: + def test_from_axis(self) -> None: + assert Pauli.from_axis(Axis.X) == Pauli.X + assert Pauli.from_axis(Axis.Y) == Pauli.Y + assert Pauli.from_axis(Axis.Z) == Pauli.Z + + def test_axis(self) -> None: + with pytest.raises(ValueError): + _ = Pauli.I.axis + assert Pauli.X.axis == Axis.X + assert Pauli.Y.axis == Axis.Y + assert Pauli.Z.axis == Axis.Z + @pytest.mark.parametrize( ("u", "p"), - itertools.product( - UNITS, - pauli.LIST, - ), + itertools.product(ComplexUnit, Pauli), ) def test_unit_mul(self, u: ComplexUnit, p: Pauli) -> None: assert np.allclose((u * p).matrix, complex(u) * p.matrix) @pytest.mark.parametrize( ("a", "b"), - itertools.product( - pauli.LIST, - pauli.LIST, - ), + itertools.product(Pauli, Pauli), ) def test_matmul(self, a: Pauli, b: Pauli) -> None: assert np.allclose((a @ b).matrix, a.matrix @ b.matrix) - @pytest.mark.parametrize( - ("plane", "s", "t", "clifford", "angle", "choice"), - itertools.product( - Plane, - (False, True), - (False, True), - Clifford, - (0, np.pi), - (False, True), - ), - ) - def test_measure_update( - self, - plane: Plane, - s: bool, - t: bool, - clifford: Clifford, - angle: float, - choice: bool, - ) -> None: - measure_update = MeasureUpdate.compute(plane, s, t, clifford) - new_angle = angle * measure_update.coeff + measure_update.add_term - vec = measure_update.new_plane.polar(new_angle) - op_mat = np.eye(2, dtype=np.complex128) / 2 - for i in range(3): - op_mat += (-1) ** (choice) * vec[i] * Clifford(i + 1).matrix / 2 - - if s: - clifford = Clifford.X @ clifford - if t: - clifford = Clifford.Z @ clifford - vec = plane.polar(angle) - op_mat_ref = np.eye(2, dtype=np.complex128) / 2 - for i in range(3): - op_mat_ref += (-1) ** (choice) * vec[i] * Clifford(i + 1).matrix / 2 - clifford_mat = clifford.matrix - op_mat_ref = clifford_mat.conj().T @ op_mat_ref @ clifford_mat - - assert np.allclose(op_mat, op_mat_ref) or np.allclose(op_mat, -op_mat_ref) + @pytest.mark.parametrize("p", Pauli.iterate(symbol_only=True)) + def test_repr(self, p: Pauli) -> None: + pstr = f"Pauli.{p.symbol.name}" + assert repr(p) == pstr + assert repr(1 * p) == pstr + assert repr(1j * p) == f"1j * {pstr}" + assert repr(-1 * p) == f"-{pstr}" + assert repr(-1j * p) == f"-1j * {pstr}" + + @pytest.mark.parametrize("p", Pauli.iterate(symbol_only=True)) + def test_str(self, p: Pauli) -> None: + pstr = p.symbol.name + assert str(p) == pstr + assert str(1 * p) == pstr + assert str(1j * p) == f"1j * {pstr}" + assert str(-1 * p) == f"-{pstr}" + assert str(-1j * p) == f"-1j * {pstr}" + + @pytest.mark.parametrize("p", Pauli) + def test_neg(self, p: Pauli) -> None: + pneg = -p + assert pneg == -p + + def test_iterate_true(self) -> None: + cmp = list(Pauli.iterate(symbol_only=True)) + assert len(cmp) == 4 + assert cmp[0] == Pauli.I + assert cmp[1] == Pauli.X + assert cmp[2] == Pauli.Y + assert cmp[3] == Pauli.Z + + def test_iterate_false(self) -> None: + cmp = list(Pauli.iterate(symbol_only=False)) + assert len(cmp) == 16 + assert cmp[0] == Pauli.I + assert cmp[1] == 1j * Pauli.I + assert cmp[2] == -1 * Pauli.I + assert cmp[3] == -1j * Pauli.I + assert cmp[4] == Pauli.X + assert cmp[5] == 1j * Pauli.X + assert cmp[6] == -1 * Pauli.X + assert cmp[7] == -1j * Pauli.X + assert cmp[8] == Pauli.Y + assert cmp[9] == 1j * Pauli.Y + assert cmp[10] == -1 * Pauli.Y + assert cmp[11] == -1j * Pauli.Y + assert cmp[12] == Pauli.Z + assert cmp[13] == 1j * Pauli.Z + assert cmp[14] == -1 * Pauli.Z + assert cmp[15] == -1j * Pauli.Z + + def test_iter_meta(self) -> None: + it = Pauli.iterate(symbol_only=False) + it_ = iter(Pauli) + for p, p_ in zip(it, it_): + assert p == p_ + assert all(False for _ in it) + assert all(False for _ in it_) + + @pytest.mark.parametrize(("p", "b"), itertools.product(Pauli.iterate(symbol_only=True), [0, 1])) + def test_eigenstate(self, p: Pauli, b: int) -> None: + ev = float(Sign.plus_if(b == 0)) if p != Pauli.I else 1 + evec = p.eigenstate(b).get_statevector() + assert np.allclose(p.matrix @ evec, ev * evec) + + def test_eigenstate_invalid(self) -> None: + with pytest.raises(ValueError): + _ = Pauli.I.eigenstate(2) diff --git a/tests/test_statevec.py b/tests/test_statevec.py index 16dbaa78..855ad0f0 100644 --- a/tests/test_statevec.py +++ b/tests/test_statevec.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from graphix.pauli import Plane +from graphix.fundamentals import Plane from graphix.sim.statevec import Statevec from graphix.states import BasicStates, PlanarState diff --git a/tests/test_statevec_backend.py b/tests/test_statevec_backend.py index 4aa2ae24..3a83ce20 100644 --- a/tests/test_statevec_backend.py +++ b/tests/test_statevec_backend.py @@ -6,10 +6,10 @@ import numpy as np import pytest -from graphix import pauli from graphix.clifford import Clifford -from graphix.pauli import Plane -from graphix.sim.base_backend import MeasurementDescription +from graphix.fundamentals import Plane +from graphix.measurements import Measurement +from graphix.pauli import Pauli from graphix.sim.statevec import Statevec, StatevectorBackend from graphix.states import BasicStates, PlanarState from tests.test_graphsim import meas_op @@ -81,7 +81,7 @@ def test_init_success(self, hadamardpattern, fx_rng: Generator) -> None: # random planar state rand_angle = fx_rng.random() * 2 * np.pi - rand_plane = fx_rng.choice(np.array([i for i in Plane])) + rand_plane = fx_rng.choice(np.array(Plane)) state = PlanarState(rand_plane, rand_angle) backend = StatevectorBackend() backend.add_nodes(hadamardpattern.input_nodes, data=state) @@ -94,7 +94,7 @@ def test_init_success(self, hadamardpattern, fx_rng: Generator) -> None: def test_init_fail(self, hadamardpattern, fx_rng: Generator) -> None: rand_angle = fx_rng.random(2) * 2 * np.pi - rand_plane = fx_rng.choice(np.array([i for i in Plane]), 2) + rand_plane = fx_rng.choice(np.array(Plane), 2) state = PlanarState(rand_plane[0], rand_angle[0]) state2 = PlanarState(rand_plane[1], rand_angle[1]) @@ -119,16 +119,16 @@ def test_deterministic_measure_one(self, fx_rng: Generator): coins = [fx_rng.choice([0, 1]), fx_rng.choice([0, 1])] expected_result = sum(coins) % 2 states = [ - pauli.X.get_eigenstate(eigenvalue=coins[0]), - pauli.Z.get_eigenstate(eigenvalue=coins[1]), + Pauli.X.eigenstate(coins[0]), + Pauli.Z.eigenstate(coins[1]), ] nodes = range(len(states)) backend.add_nodes(nodes=nodes, data=states) backend.entangle_nodes(edge=(nodes[0], nodes[1])) - measurement_description = MeasurementDescription(plane=Plane.XY, angle=0) + measurement = Measurement(0, Plane.XY) node_to_measure = backend.node_index[0] - result = backend.measure(node=node_to_measure, measurement_description=measurement_description) + result = backend.measure(node=node_to_measure, measurement=measurement) assert result == expected_result def test_deterministic_measure(self): @@ -137,15 +137,15 @@ def test_deterministic_measure(self): # plus state (default) backend = StatevectorBackend() n_neighbors = 10 - states = [pauli.X.get_eigenstate()] + [pauli.Z.get_eigenstate() for i in range(n_neighbors)] + states = [Pauli.X.eigenstate()] + [Pauli.Z.eigenstate() for i in range(n_neighbors)] nodes = range(len(states)) backend.add_nodes(nodes=nodes, data=states) for i in range(1, n_neighbors + 1): backend.entangle_nodes(edge=(nodes[0], i)) - measurement_description = MeasurementDescription(plane=Plane.XY, angle=0) + measurement = Measurement(0, Plane.XY) node_to_measure = backend.node_index[0] - result = backend.measure(node=node_to_measure, measurement_description=measurement_description) + result = backend.measure(node=node_to_measure, measurement=measurement) assert result == 0 assert list(backend.node_index) == list(range(1, n_neighbors + 1)) @@ -157,9 +157,9 @@ def test_deterministic_measure_many(self): n_traps = 5 n_neighbors = 5 n_whatever = 5 - traps = [pauli.X.get_eigenstate() for _ in range(n_traps)] - dummies = [pauli.Z.get_eigenstate() for _ in range(n_neighbors)] - others = [pauli.I.get_eigenstate() for _ in range(n_whatever)] + traps = [Pauli.X.eigenstate() for _ in range(n_traps)] + dummies = [Pauli.Z.eigenstate() for _ in range(n_neighbors)] + others = [Pauli.I.eigenstate() for _ in range(n_whatever)] states = traps + dummies + others nodes = range(len(states)) backend.add_nodes(nodes=nodes, data=states) @@ -171,11 +171,11 @@ def test_deterministic_measure_many(self): backend.entangle_nodes(edge=(other, dummy)) # Same measurement for all traps - measurement_description = MeasurementDescription(plane=Plane.XY, angle=0) + measurement = Measurement(0, Plane.XY) for trap in nodes[:n_traps]: node_to_measure = trap - result = backend.measure(node=node_to_measure, measurement_description=measurement_description) + result = backend.measure(node=node_to_measure, measurement=measurement) assert result == 0 assert list(backend.node_index) == list(range(n_traps, n_neighbors + n_traps + n_whatever)) @@ -191,16 +191,14 @@ def test_deterministic_measure_with_coin(self, fx_rng: Generator): n_neighbors = 10 coins = [fx_rng.choice([0, 1])] + [fx_rng.choice([0, 1]) for _ in range(n_neighbors)] expected_result = sum(coins) % 2 - states = [pauli.X.get_eigenstate(eigenvalue=coins[0])] + [ - pauli.Z.get_eigenstate(eigenvalue=coins[i + 1]) for i in range(n_neighbors) - ] + states = [Pauli.X.eigenstate(coins[0])] + [Pauli.Z.eigenstate(coins[i + 1]) for i in range(n_neighbors)] nodes = range(len(states)) backend.add_nodes(nodes=nodes, data=states) for i in range(1, n_neighbors + 1): backend.entangle_nodes(edge=(nodes[0], i)) - measurement_description = MeasurementDescription(plane=Plane.XY, angle=0) + measurement = Measurement(0, Plane.XY) node_to_measure = backend.node_index[0] - result = backend.measure(node=node_to_measure, measurement_description=measurement_description) + result = backend.measure(node=node_to_measure, measurement=measurement) assert result == expected_result assert list(backend.node_index) == list(range(1, n_neighbors + 1)) diff --git a/tests/test_transpiler.py b/tests/test_transpiler.py index 1534c61a..047242f6 100644 --- a/tests/test_transpiler.py +++ b/tests/test_transpiler.py @@ -4,7 +4,7 @@ import pytest from numpy.random import PCG64, Generator -from graphix.pauli import Plane +from graphix.fundamentals import Plane from graphix.random_objects import rand_circuit, rand_gate from graphix.transpiler import Circuit