diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index eb78cccd..0c8f2ad3 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -823,15 +823,15 @@ def num_vertices(self) -> int: """Returns the amount of vertices in the graph.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def num_edges(self) -> int: + def num_edges(self, s: Optional[VT]=None, t: Optional[VT]=None) -> int: """Returns the amount of edges in the graph""" - raise NotImplementedError("Not implemented on backend " + type(self).backend) + return len(list(self.edges(s, t))) - def vertices(self) -> Sequence[VT]: + def vertices(self) -> Iterable[VT]: """Iterator over all the vertices.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) - def edges(self, s: Optional[VT]=None, t: Optional[VT]=None) -> Sequence[ET]: + def edges(self, s: Optional[VT]=None, t: Optional[VT]=None) -> Iterable[ET]: """Iterator that returns all the edges in the graph, or all the edges connecting the pair of vertices. Output type depends on implementation in backend.""" raise NotImplementedError("Not implemented on backend " + type(self).backend) diff --git a/pyzx/graph/graph_s.py b/pyzx/graph/graph_s.py index 1f06d680..39ce5443 100644 --- a/pyzx/graph/graph_s.py +++ b/pyzx/graph/graph_s.py @@ -15,7 +15,7 @@ # limitations under the License. from fractions import Fraction -from typing import Tuple, Dict, Set, Any +from typing import Optional, Tuple, Dict, Set, Any from .base import BaseGraph @@ -202,9 +202,16 @@ def remove_edge(self, edge): def num_vertices(self): return len(self.graph) - def num_edges(self): - #return self.nedges - return len(self.edge_set()) + def num_edges(self, s=None, t=None): + if s is not None and t is not None: + if self.connected(s, t): + return 1 + else: + return 0 + elif s is not None: + return self.vertex_degree(s) + else: + return len(list(self.edges())) def vertices(self): return self.graph.keys() diff --git a/pyzx/graph/multigraph.py b/pyzx/graph/multigraph.py index b8e17179..fbf63d8a 100644 --- a/pyzx/graph/multigraph.py +++ b/pyzx/graph/multigraph.py @@ -255,9 +255,11 @@ def remove_edge(self, edge): def num_vertices(self): return len(self.graph) - def num_edges(self): - return self.nedges - #return len(self.edge_set()) + def num_edges(self, s=None, t=None): + if s != None or t != None: + return len(list(self.edges(s, t))) + else: + return self.nedges def vertices(self): return self.graph.keys() diff --git a/pyzx/pauliweb.py b/pyzx/pauliweb.py index 18ac1356..e0531509 100644 --- a/pyzx/pauliweb.py +++ b/pyzx/pauliweb.py @@ -44,13 +44,13 @@ def __init__(self, g: BaseGraph[VT,ET], c: Set[VT]): self.g = g self.c = c - def vertices(self): + def vertices(self) -> Set[VT]: vs = self.c.copy() for v in self.c: vs |= set(self.g.neighbors(v)) return vs - def half_edges(self): + def half_edges(self) -> Dict[Tuple[VT,VT],str]: es: Dict[Tuple[VT,VT],str] = dict() for v in self.c: for e in self.g.incident_edges(v): @@ -67,7 +67,7 @@ def half_edges(self): es[(v1,v)] = multiply_paulis(t2, ty) return es - def boundary(self): + def boundary(self) -> Set[VT]: b: Dict[VT, int] = dict() for v in self.c: for n in self.g.neighbors(v): @@ -136,8 +136,8 @@ def preprocess(g: BaseGraph[VT,ET]): return (in_circ, out_circ) -def transpose_corrections(c) -> Dict[VT, Set[VT]]: - ct = dict() +def transpose_corrections(c: Dict[VT, Set[VT]]) -> Dict[VT, Set[VT]]: + ct: Dict[VT, Set[VT]] = dict() for k,s in c.items(): for v in s: if v not in ct: ct[v] = set() diff --git a/pyzx/rules.py b/pyzx/rules.py index 9384cff9..490097c9 100644 --- a/pyzx/rules.py +++ b/pyzx/rules.py @@ -122,9 +122,9 @@ def match_bialg_parallel( v1n = [n for n in g.neighbors(v1) if not n == v0] if (all([types[n] == v1t and phases[n] == 0 for n in v0n]) and # all neighbors of v0 are of the same type as v1 all([types[n] == v0t and phases[n] == 0 for n in v1n]) and # all neighbors of v1 are of the same type as v0 - len(g.edges(v0, v1)) == 1 and # there is exactly one edge between v0 and v1 - len(g.edges(v0, v0)) == 0 and # there are no self-loops on v0 - len(g.edges(v1, v1)) == 0): # there are no self-loops on v1 + g.num_edges(v0, v1) == 1 and # there is exactly one edge between v0 and v1 + g.num_edges(v0, v0) == 0 and # there are no self-loops on v0 + g.num_edges(v1, v1) == 0): # there are no self-loops on v1 i += 1 for vn in [v0n, v1n]: for v in vn: diff --git a/pyzx/simplify.py b/pyzx/simplify.py index b11c0640..4f358221 100644 --- a/pyzx/simplify.py +++ b/pyzx/simplify.py @@ -318,8 +318,8 @@ def max_cut(g: BaseGraph[VT,ET], vs0: Optional[Set[VT]]=None, vs1: Optional[Set[ This uses the quadratic-time SG3 heuristic explained by Wang et al in https://arxiv.org/abs/2312.10895 . """ - if vs0 == None: vs0 = set() - if vs1 == None: vs1 = set() + if vs0 is None: vs0 = set() + if vs1 is None: vs1 = set() # print(f'vs0={vs0} vs1={vs1}') remaining = set(g.vertices()) - vs0 - vs1 while len(remaining) > 0: @@ -337,7 +337,7 @@ def max_cut(g: BaseGraph[VT,ET], vs0: Optional[Set[VT]]=None, vs1: Optional[Set[ in0 = wt0 >= wt1 # print(f'choosing {v_max} for set {"vs0" if in0 else "vs1"}') - if v_max == None: raise RuntimeError("No max found") + if v_max is None: raise RuntimeError("No max found") remaining.remove(v_max) if in0: vs0.add(v_max) else: vs1.add(v_max)