Skip to content

Commit

Permalink
🚨 Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
EarlMilktea committed Sep 18, 2024
1 parent 4633a16 commit 21f1194
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions graphix/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
from __future__ import annotations

import abc
from abc import ABC
from typing import ClassVar

import numpy as np
import numpy.typing as npt
import pydantic
import typing_extensions
from pydantic import BaseModel

from graphix.pauli import Plane


# generic class State for all States
class State(abc.ABC):
class State(ABC):
"""Abstract base class for single qubit states objects.
Only requirement for concrete classes is to have
Expand All @@ -23,17 +24,16 @@ class State(abc.ABC):
"""

@abc.abstractmethod
def get_statevector(self) -> npt.NDArray:
def get_statevector(self) -> npt.NDArray[np.complex128]:
"""Return the state vector."""
...

def get_densitymatrix(self) -> npt.NDArray:
def get_densitymatrix(self) -> npt.NDArray[np.complex128]:
"""Return the density matrix."""
# return DM in 2**n x 2**n dim (2x2 here)
return np.outer(self.get_statevector(), self.get_statevector().conj())


class PlanarState(pydantic.BaseModel, State):
class PlanarState(BaseModel, State):
"""Light object used to instantiate backends.
doesn't cover all possible states but this is
Expand All @@ -60,16 +60,16 @@ def __str__(self) -> str:
"""Return a string description of the planar state."""
return f"PlanarState object defined in plane {self.plane} with angle {self.angle}."

def get_statevector(self) -> npt.NDArray:
def get_statevector(self) -> npt.NDArray[np.complex128]:
"""Return the state vector."""
if self.plane == Plane.XY:
return np.array([1, np.exp(1j * self.angle)]) / np.sqrt(2)
return np.asarray([1 / np.sqrt(2), np.exp(1j * self.angle) / np.sqrt(2)], dtype=np.complex128)

if self.plane == Plane.YZ:
return np.array([np.cos(self.angle / 2), 1j * np.sin(self.angle / 2)])
return np.asarray([np.cos(self.angle / 2), 1j * np.sin(self.angle / 2)], dtype=np.complex128)

if self.plane == Plane.XZ:
return np.array([np.cos(self.angle / 2), np.sin(self.angle / 2)])
return np.asarray([np.cos(self.angle / 2), np.sin(self.angle / 2)], dtype=np.complex128)
# other case never happens since exhaustive
typing_extensions.assert_never(self.plane)

Expand Down

0 comments on commit 21f1194

Please sign in to comment.