Skip to content

Commit

Permalink
Support GraphBLAS expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
alugowski committed Aug 31, 2023
1 parent 7f161d4 commit ea08eab
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 11 deletions.
22 changes: 18 additions & 4 deletions matrepr/adapters/graphblas_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,28 @@ def get_supported_types() -> Iterable[Tuple[str, bool]]:
("graphblas.core.matrix.Matrix", True),
("graphblas.Vector", True),
("graphblas.core.vector.Vector", True),
# Expressions:
("graphblas.core.expr.InfixExprBase", True),
("graphblas.core.infix.VectorInfixExpr", True),
("graphblas.core.infix.MatrixInfixExpr", True),
("graphblas.core.infix.VectorMatMulExpr", True),
("graphblas.core.infix.MatrixMatMulExpr", True),
("graphblas.core.vector.VectorExpression", True),
("graphblas.core.matrix.MatrixExpression", True),
]

@staticmethod
def adapt(mat: Any):
from .graphblas_impl import GraphBLASMatrixAdapter, GraphBLASVectorAdapter
if "Matrix" in str(type(mat)):
return GraphBLASMatrixAdapter(mat)
elif "Vector" in str(type(mat)):
return GraphBLASVectorAdapter(mat)
type_name = type(mat).__name__
if hasattr(mat, "_get_value"):
# This is an expression. Compute the value and format that.
# noinspection PyProtectedMember
mat = mat._get_value()

if "Matrix" == type(mat).__name__:
return GraphBLASMatrixAdapter(mat, type_name)
elif "Vector" == type(mat).__name__:
return GraphBLASVectorAdapter(mat, type_name)
else:
raise ValueError("Unknown type: " + str(type(mat)))
13 changes: 7 additions & 6 deletions matrepr/adapters/graphblas_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@


class GraphBLASAdapter:
def __init__(self, mat):
def __init__(self, mat, type_name):
self.mat = mat
self.type_name = type_name

def get_shape(self) -> tuple:
return self.mat.shape
Expand All @@ -30,7 +31,7 @@ def get_format(self, is_transposed=False):
return None

def describe(self) -> str:
parts = [f"gb.{type(self.mat).__name__}"]
parts = [f"gb.{self.type_name}"]

fmt = self.get_format()
if fmt:
Expand All @@ -42,8 +43,8 @@ def describe(self) -> str:


class GraphBLASMatrixAdapter(GraphBLASAdapter, MatrixAdapterCoo):
def __init__(self, mat: gb.Matrix):
super().__init__(mat)
def __init__(self, mat: gb.Matrix, type_name):
super().__init__(mat, type_name)

def get_coo(self, row_range: Tuple[int, int], col_range: Tuple[int, int]) -> Iterable[Tuple[int, int, Any]]:
ret = self.mat[slice(*row_range), slice(*col_range)]
Expand All @@ -55,8 +56,8 @@ def get_coo(self, row_range: Tuple[int, int], col_range: Tuple[int, int]) -> Ite


class GraphBLASVectorAdapter(GraphBLASAdapter, MatrixAdapterRow):
def __init__(self, vec: gb.Vector):
super().__init__(vec)
def __init__(self, vec: gb.Vector, type_name):
super().__init__(vec, type_name)

def get_row(self, row_idx: int, col_range: Tuple[int, int]) -> Iterable[Any]:
assert row_idx == 0
Expand Down
15 changes: 14 additions & 1 deletion matrepr/patch/graphblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,27 @@

from .. import *

from matrepr.adapters.graphblas_driver import GraphBLASDriver
import graphblas


def _str_(mat):
from matrepr.adapters.graphblas_driver import GraphBLASDriver
# Enable terminal width detection
return to_str(GraphBLASDriver.adapt(mat), width_str=0, max_cols=9999)


graphblas.Matrix.__repr__ = _str_
graphblas.Vector.__repr__ = _str_


def _patch_all_supported_types():
# noinspection PyProtectedMember
from pydoc import locate

for tp_str, _ in GraphBLASDriver.get_supported_types():
cls = locate(tp_str)
if cls:
cls.__repr__ = _str_


_patch_all_supported_types()

0 comments on commit ea08eab

Please sign in to comment.