Skip to content

Commit

Permalink
Merge pull request #39 from ksunden/edges
Browse files Browse the repository at this point in the history
Initial rework as conversion edges
  • Loading branch information
ksunden authored Mar 7, 2024
2 parents 830ce25 + 3c0c6ff commit a38c182
Show file tree
Hide file tree
Showing 15 changed files with 520 additions and 174 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/ambv/black
rev: 23.3.0
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.0.0
hooks:
- id: flake8
- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
rev: 0.7.1
hooks:
- id: nbstripout
28 changes: 17 additions & 11 deletions data_prototype/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,22 +102,28 @@ def scatter(
pipeline.append(lambda x: np.ma.ravel(x))
pipeline.append(lambda y: np.ma.ravel(y))
pipeline.append(
lambda s: np.ma.ravel(s)
if s is not None
else [20]
if mpl.rcParams["_internal.classic_mode"]
else [mpl.rcParams["lines.markersize"] ** 2.0]
lambda s: (
np.ma.ravel(s)
if s is not None
else (
[20]
if mpl.rcParams["_internal.classic_mode"]
else [mpl.rcParams["lines.markersize"] ** 2.0]
)
)
)
# TODO plotnonfinite/mask combining
pipeline.append(
lambda marker: marker
if marker is not None
else mpl.rcParams["scatter.marker"]
lambda marker: (
marker if marker is not None else mpl.rcParams["scatter.marker"]
)
)
pipeline.append(
lambda marker: marker
if isinstance(marker, mmarkers.MarkerStyle)
else mmarkers.MarkerStyle(marker)
lambda marker: (
marker
if isinstance(marker, mmarkers.MarkerStyle)
else mmarkers.MarkerStyle(marker)
)
)
pipeline.append(
FunctionConversionNode.from_funcs(
Expand Down
211 changes: 85 additions & 126 deletions data_prototype/containers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from __future__ import annotations

from typing import (
Protocol,
Dict,
Expand All @@ -8,7 +9,6 @@
Union,
Callable,
MutableMapping,
TypeAlias,
)
import uuid

Expand All @@ -17,92 +17,25 @@
import numpy as np
import pandas as pd

from .description import Desc, desc_like

class _MatplotlibTransform(Protocol):
def transform(self, verts):
...

def __sub__(self, other) -> "_MatplotlibTransform":
...
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .conversion_edge import Graph

ShapeSpec: TypeAlias = Tuple[Union[str, int], ...]

class _MatplotlibTransform(Protocol):
def transform(self, verts): ...

@dataclass(frozen=True)
class Desc:
    """Immutable description of one named field: its shape, dtype and units.

    ``shape`` is a tuple whose entries are either integers (fixed sizes) or
    strings (symbolic sizes).  A symbolic entry is a one-character variable
    name optionally followed by an integer offset, e.g. ``"N"`` or ``"N+1"``.
    """

    # TODO: sort out how to actually spell this. We need to know:
    #   - what the number of dimensions is (1d vs 2d vs ...)
    #   - is this a fixed size dimension (e.g. 2 for xextent)
    #   - is this a variable size depending on the query (e.g. N)
    #   - what is the relative size to the other variable values (N vs N+1)
    # We are probably going to have to implement a DSL for this (😞)
    shape: ShapeSpec
    # TODO: is using a string better?
    dtype: np.dtype
    # TODO: do we want to include this at this level? "naive" means unit-unaware.
    units: str = "naive"

    @staticmethod
    def validate_shapes(
        specification: dict[str, ShapeSpec | "Desc"],
        actual: dict[str, ShapeSpec | "Desc"],
        *,
        broadcast: bool = False,
    ) -> None:
        """Check that the shapes in *actual* satisfy *specification*.

        Shapes are compared right-aligned, one dimension at a time.  Every
        symbolic dimension in the specification binds its variable to the
        matching actual size (with the specification's offset subtracted);
        a variable must resolve to the same value everywhere it appears,
        across all fields.

        Parameters
        ----------
        specification
            Required shapes keyed by field name; values are shape tuples or
            `Desc` instances (whose ``shape`` is used).
        actual
            Shapes to check, in the same form.  Keys that do not appear in
            *specification* are ignored.
        broadcast
            If False, each actual shape must have exactly as many dimensions
            as its specification.  If True, an actual shape may have fewer
            dimensions, and any actual dimension equal to 1 is accepted
            without being compared.

        Returns
        -------
        None
            The check succeeds silently; failures are reported by raising.

        Raises
        ------
        KeyError
            If a field named in *specification* is absent from *actual*.
        ValueError
            If a shape has an incompatible number of dimensions, a fixed
            dimension differs, or a variable resolves to two different values.
        """
        # Value bound to each specification variable: an int for a concrete
        # size, or (name, offset) when the actual dimension is itself symbolic.
        specvars: dict[str, int | tuple[str, int]] = {}
        for fieldname in specification:
            spec = specification[fieldname]
            if fieldname not in actual:
                raise KeyError(
                    f"Actual is missing {fieldname!r}, required by specification."
                )
            desc = actual[fieldname]
            if isinstance(spec, Desc):
                spec = spec.shape
            if isinstance(desc, Desc):
                desc = desc.shape

            # Extra leading dimensions are never allowed; missing ones are
            # tolerated only when broadcasting.
            too_many_dims = len(desc) > len(spec)
            if too_many_dims or (not broadcast and len(desc) != len(spec)):
                raise ValueError(
                    f"{fieldname!r} shape {desc} incompatible with specification "
                    f"{spec}."
                )

            # Walk from the trailing dimension so shorter shapes right-align.
            for speccomp, desccomp in zip(spec[::-1], desc[::-1]):
                if broadcast and desccomp == 1:
                    continue
                if isinstance(speccomp, str):
                    # First character is the variable name, the remainder
                    # (if any) is its integer offset: "N+1" -> ("N", 1).
                    specv, specoff = speccomp[0], int(speccomp[1:] or 0)

                    if isinstance(desccomp, str):
                        descv, descoff = desccomp[0], int(desccomp[1:] or 0)
                        entry = (descv, descoff - specoff)
                    else:
                        entry = desccomp - specoff

                    if specv in specvars and entry != specvars[specv]:
                        raise ValueError(f"Found two incompatible values for {specv!r}")

                    specvars[specv] = entry
                elif speccomp != desccomp:
                    raise ValueError(
                        f"{fieldname!r} shape {desc} incompatible with specification "
                        f"{spec}"
                    )
def __sub__(self, other) -> "_MatplotlibTransform": ...


class DataContainer(Protocol):
def query(
self,
# TODO 3D?!!
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
/,
) -> Tuple[Dict[str, Any], Union[str, int]]:
"""
Expand Down Expand Up @@ -132,6 +65,7 @@ def query(
This is a key that clients can use to cache down-stream
computations on this data.
"""
...

def describe(self) -> Dict[str, Desc]:
"""
Expand All @@ -141,27 +75,29 @@ def describe(self) -> Dict[str, Desc]:
-------
Dict[str, Desc]
"""
...


class NoNewKeys(ValueError):
...
class NoNewKeys(ValueError): ...


class ArrayContainer:
def __init__(self, **data):
self._data = data
self._cache_key = str(uuid.uuid4())
self._desc = {
k: Desc(v.shape, v.dtype)
if isinstance(v, np.ndarray)
else Desc((), type(v))
k: (
Desc(v.shape, v.dtype)
if isinstance(v, np.ndarray)
else Desc((), type(v))
)
for k, v in data.items()
}

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
return dict(self._data), self._cache_key

Expand All @@ -185,8 +121,8 @@ def __init__(self, **shapes):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
return {k: np.random.randn(*d.shape) for k, d in self._desc.items()}, str(
uuid.uuid4()
Expand Down Expand Up @@ -253,31 +189,44 @@ def _query_hash(self, coord_transform, size):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
hash_key = self._query_hash(coord_transform, size)
if hash_key in self._cache:
return self._cache[hash_key], hash_key
# hash_key = self._query_hash(coord_transform, size)
# if hash_key in self._cache:
# return self._cache[hash_key], hash_key

desc = Desc(("N",), np.dtype("f8"))
xy = {"x": desc, "y": desc}
data_lim = graph.evaluator(
desc_like(xy, coordinates="data"),
desc_like(xy, coordinates=parent_coordinates),
).inverse

screen_size = graph.evaluator(
desc_like(xy, coordinates=parent_coordinates),
desc_like(xy, coordinates="display"),
)

xpix, ypix = size
x_data, _ = coord_transform.transform(
np.vstack(
[
np.linspace(0, 1, int(xpix) * 2),
np.zeros(int(xpix) * 2),
]
).T
).T
_, y_data = coord_transform.transform(
np.vstack(
[
np.zeros(int(ypix) * 2),
np.linspace(0, 1, int(ypix) * 2),
]
).T
).T
screen_dims = screen_size.evaluate({"x": [0, 1], "y": [0, 1]})
xpix, ypix = np.ceil(np.abs(np.diff(screen_dims["x"]))), np.ceil(
np.abs(np.diff(screen_dims["y"]))
)

x_data = data_lim.evaluate(
{
"x": np.linspace(0, 1, int(xpix) * 2),
"y": np.zeros(int(xpix) * 2),
}
)["x"]
y_data = data_lim.evaluate(
{
"x": np.zeros(int(ypix) * 2),
"y": np.linspace(0, 1, int(ypix) * 2),
}
)["y"]

hash_key = str(uuid.uuid4())
ret = self._cache[hash_key] = dict(
**{k: f(x_data) for k, f in self._xfuncs.items()},
**{k: f(y_data) for k, f in self._yfuncs.items()},
Expand All @@ -302,11 +251,21 @@ def __init__(self, raw_data, num_bins: int):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
dmin, dmax = self._full_range
xmin, ymin, xmax, ymax = coord_transform.transform([[0, 0], [1, 1]]).flatten()

desc = Desc(("N",), np.dtype("f8"))
xy = {"x": desc, "y": desc}
data_lim = graph.evaluator(
desc_like(xy, coordinates="data"),
desc_like(xy, coordinates=parent_coordinates),
).inverse

pts = data_lim.evaluate({"x": (0, 1), "y": (0, 1)})
xmin, xmax = pts["x"]
ymin, ymax = pts["y"]

xmin, xmax = np.clip([xmin, xmax], dmin, dmax)
hash_key = hash((xmin, xmax))
Expand All @@ -333,7 +292,7 @@ def describe(self) -> Dict[str, Desc]:


class SeriesContainer:
_data: pd.DataFrame
_data: pd.Series
_index_name: str
_hash_key: str

Expand All @@ -350,8 +309,8 @@ def __init__(self, series: pd.Series, *, index_name: str, col_name: str):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
return {
self._index_name: self._data.index.values,
Expand Down Expand Up @@ -392,8 +351,8 @@ def __init__(

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
ret = {}
if self._index_name is not None:
Expand All @@ -415,10 +374,10 @@ def __init__(self, data: DataContainer, mapping: Dict[str, str]):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
base, cache_key = self._data.query(coord_transform, size)
base, cache_key = self._data.query(graph, parent_coordinates)
return {v: base[k] for k, v in self._mapping.items()}, cache_key

def describe(self):
Expand All @@ -433,13 +392,13 @@ def __init__(self, *data: DataContainer):

def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
cache_keys = []
ret = {}
for data in self._datas:
base, cache_key = data.query(coord_transform, size)
base, cache_key = data.query(graph, parent_coordinates)
ret.update(base)
cache_keys.append(cache_key)
return ret, hash(tuple(cache_keys))
Expand All @@ -451,11 +410,11 @@ def describe(self):
class WebServiceContainer:
def query(
self,
coord_transform: _MatplotlibTransform,
size: Tuple[int, int],
graph: Graph,
parent_coordinates: str = "axes",
) -> Tuple[Dict[str, Any], Union[str, int]]:
def hit_some_database():
{}, "1"
return {}, "1"

data, etag = hit_some_database()
return data, etag
Loading

0 comments on commit a38c182

Please sign in to comment.