Added initial package setup
davystrong committed May 12, 2023
1 parent 195e301 commit 0eb7ead
Showing 7 changed files with 366 additions and 0 deletions.
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 David R. Armstrong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
49 changes: 49 additions & 0 deletions README.md
@@ -0,0 +1,49 @@
# Einsum Pipe

A Python package to compile multiple Numpy [einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) operations into one.

## Example

Given two arrays:
```python
import numpy as np

A = np.random.rand(32, 32, 10, 5)
B = np.random.rand(32, 32, 10, 5)
```

We frequently need to run multiple reshape/transpose/products/trace/etc., such as:
```python
C = np.einsum('ij...,kl...->ikjl...', A, B)
D = C.reshape([2, ]*20 + [10, 5])
E = D.transpose([2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14,
                 15, 16, 17, 18, 19, 0, 1, 10, 11, 20, 21])
F = E.reshape([256, 256, 4, 4, 10, 5])
X = np.trace(F)
```

This creates several intermediate arrays, some of which can be large (`C` alone has shape `(32, 32, 32, 32, 10, 5)`). Instead, multiple `np.einsum` operations can be combined into one. By carefully modifying the input shapes, this works even when the intermediate data is reshaped along the way, provided the shapes are all [compatible](#shape-compatibility). The previous example can be performed in a single `np.einsum` step:
```python
X = einsum_pipe(
    'ik...,jl...->ijkl...',
    [2, ]*20 + [10, 5],
    'abcde fghij klmno pqrst...->cde fghij mno pqrst ab kl...',
    [256, 256, 4, 4, 10, 5],
    'ii...',
    A, B
)
```

Internally, this calculates compatible input shapes, `(4, 8, 4, 8, 50)` and `(32, 32, 50)`, and a single combined set of `np.einsum` subscripts, `"ebdbc,aac->edc"`. `A` and `B` are reshaped (which is generally almost free), the single `np.einsum` operation is run, and the output is reshaped to the expected output shape.
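
The equivalence described above can be checked on a scaled-down analogue using only NumPy. This is a minimal sketch: the small shapes, the transpose order, and the reshaped input are chosen by hand for illustration (they are assumptions, not output of `einsum_pipe`), while the combined subscripts are the ones quoted above.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((4, 4, 3))
B = rng.random((4, 4, 3))

# Step-by-step version: same structure as the example above, with smaller axes.
C = np.einsum('ij...,kl...->ikjl...', A, B)    # shape (4, 4, 4, 4, 3)
D = C.reshape([2] * 8 + [3])                    # split every axis of size 4 into 2 * 2
E = D.transpose([1, 2, 3, 5, 6, 7, 0, 4, 8])    # regroup the binary axes
F = E.reshape([8, 8, 2, 2, 3])
X_steps = np.trace(F)                           # trace over the first two axes -> (2, 2, 3)

# Single-step version: reshape A only, then run one einsum.
A_r = A.reshape(2, 2, 2, 2, 3)
X_single = np.einsum('ebdbc,aac->edc', A_r, B)  # shape (2, 2, 3)

print(np.allclose(X_steps, X_single))
```

The step-by-step version materialises `C` and `F`, whereas the single-step version only creates the final result.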

## Syntax

The syntax is based on Numpy's `einsum`, extended to accept multiple subscript strings and the shapes of the intermediate arrays. The input arrays can be placed at the end, as shown, or next to the subscript definitions. In this example, only two arrays are used, both at the start of the pipe, but more arrays can be added at later stages. The output of the previous step is always the first input of the subsequent step.
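
The chaining rule (the previous output becomes the first input of the next step, and later steps may bring in new arrays) can be illustrated with plain NumPy. This sketch does not call `einsum_pipe`; the array names and the combined subscripts are written by hand as an assumption of what such a two-step pipe reduces to.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.random((3, 4))
B = rng.random((4, 5))
C = rng.random((5, 2))

# Two separate steps: the second step takes the first step's output
# as its first input and introduces a new array, C.
step1 = np.einsum('ij,jk->ik', A, B)
two_step = np.einsum('ik,kl->il', step1, C)

# The same computation as one einsum call over all three arrays.
one_step = np.einsum('ij,jk,kl->il', A, B, C)

print(np.allclose(two_step, one_step))
```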

## Shape Compatibility

A shape is compatible with the shape of the previous output if both can be obtained from a common finer shape by multiplying together runs of adjacent dimensions. For example, `(32, 32)` and `(4, 256)` are compatible, since both can be built from the shape `(4, 8, 4, 8)`: `(4*8, 4*8)` and `(4, 8*4*8)`. On the other hand, `(2, 3)` and `(3, 2)` aren't directly compatible: neither 2 nor 3 divides the other, so no common finer shape exists.

Note that transposition of axes also causes the transposition of the compatible shape, so while `[(3, 2), 'ij->ij', (2, 3)]` isn't valid, `[(3, 2), 'ij->ji', (2, 3)]` is.
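
The compatibility rule (ignoring the transposition case) can be sketched as a search for a common finer shape, working from the last axis to the first. The function name `common_refinement` is hypothetical and this is not the package's implementation, only an illustration of the rule.

```python
from typing import Optional, Sequence, Tuple


def common_refinement(shape_a: Sequence[int], shape_b: Sequence[int]) -> Optional[Tuple[int, ...]]:
    """Return a shape from which both inputs can be built by merging
    adjacent dimensions, or None if no such shape exists."""
    a, b = list(shape_a), list(shape_b)
    out = []
    while a and b:
        x, y = a.pop(), b.pop()
        if x == y:
            out.append(x)
        elif x > y:
            if x % y != 0:
                return None
            out.append(y)
            a.append(x // y)  # the remainder of x still has to be matched
        else:
            if y % x != 0:
                return None
            out.append(x)
            b.append(y // x)  # the remainder of y still has to be matched
    # Any leftover dimensions must all be 1, otherwise the total sizes differ.
    if any(dim != 1 for dim in a + b):
        return None
    return tuple(reversed(out))


print(common_refinement((32, 32), (4, 256)))  # (4, 8, 32)
print(common_refinement((2, 3), (3, 2)))      # None
```

The first result, `(4, 8, 32)`, is the coarsest common shape; `(4, 8, 4, 8)` from the text above is a further split of it and works equally well.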

I plan to implement a best-effort fallback that reduces a sequence of operations to as few `np.einsum` calls as possible when some of the shapes are incompatible.

## Numpy Operations
25 changes: 25 additions & 0 deletions pyproject.toml
@@ -0,0 +1,25 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "einsum_pipe"
version = "0.0.1"
authors = [
{ name="David R. Armstrong" },
]
description = "A Python package to compile multiple Numpy einsum operations into one"
readme = "README.md"
requires-python = ">=3.11"  # einsum_script.py imports typing.Self, which was added in Python 3.11
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
dependencies = [
"numpy"
]

[project.urls]
"Homepage" = "https://github.com/davystrong/Einsum-Pipe"
"Bug Tracker" = "https://github.com/davystrong/Einsum-Pipe/issues"
2 changes: 2 additions & 0 deletions src/einsum_pipe/__init__.py
@@ -0,0 +1,2 @@
from .einsum_script import EinsumScript, EinsumComp
from .einsum_pipe import einsum_pipe
37 changes: 37 additions & 0 deletions src/einsum_pipe/bidict.py
@@ -0,0 +1,37 @@
from typing import Dict, Generic, Iterator, TypeVar
from collections.abc import MutableMapping


K = TypeVar('K')
V = TypeVar('V')


class BiDict(MutableMapping, Generic[K, V]):
    """A custom dictionary that keeps track of the inverse mapping from values to keys. Values must be unique too
    """

    def __init__(self):
        self.store: Dict[K, V] = {}
        self.inverse: Dict[V, K] = {}

    def __getitem__(self, key: K) -> V:
        return self.store[key]

    def __setitem__(self, key: K, value: V) -> None:
        if key in self:
            del self.inverse[self[key]]
        self.inverse[value] = key
        self.store[key] = value

    def __delitem__(self, key: K) -> None:
        del self.inverse[self[key]]
        del self.store[key]

    def values(self):
        return self.inverse.keys()

    def __iter__(self) -> Iterator[K]:
        return iter(self.store)

    def __len__(self):
        return len(self.store)
36 changes: 36 additions & 0 deletions src/einsum_pipe/einsum_pipe.py
@@ -0,0 +1,36 @@
from functools import reduce
from typing import List

import numpy as np
from .einsum_script import EinsumScript


def einsum_pipe(*args):
    # Strings are subscripts and lists/tuples are intermediate shapes; everything else is an operand
    subs = [arg for arg in args if isinstance(arg, (str, list, tuple))]
    ops = [arg for arg in args if not isinstance(arg, (str, list, tuple))]
    ops_index = 0
    scripts: List[EinsumScript] = []

    while len(subs) > 0:
        input_shapes = []
        sub = subs.pop(0)
        if not isinstance(sub, str):
            # An explicit shape describes the first input of the subscripts that follow it
            input_shapes.append(sub)
            sub = subs.pop(0)
        assert isinstance(sub, str)

        # Number of operands this step consumes from `ops`
        num_ops = sub.count(',') + 1 - len(input_shapes)

        input_shapes.extend(tuple(x.shape)
                            for x in ops[ops_index:ops_index+num_ops])
        ops_index += num_ops

        x = EinsumScript.parse(input_shapes, sub)
        scripts.append(x)

    # Chain all steps into one script, then merge axes that always appear together
    output_script = reduce(lambda x, y: x+y, scripts)
    output_script.simplify()
    reshaped_ops = [np.reshape(op, [comp.size for comp in inp])
                    for op, inp in zip(ops, output_script.inputs)]
    raw_output: np.ndarray = np.einsum(str(output_script), *reshaped_ops)
    return raw_output.reshape([comp.size for comp in scripts[-1].outputs])
196 changes: 196 additions & 0 deletions src/einsum_pipe/einsum_script.py
@@ -0,0 +1,196 @@
import copy
import math
from typing import Generator, List, Self, Tuple, TypeVar, Union, cast
from .bidict import BiDict


class EinsumComp:
    def __init__(self, size: int) -> None:
        self.size = size


class NullTag:
    pass


class EinsumScript:
    def __init__(self, inputs: List[List[EinsumComp]], outputs: List[EinsumComp]) -> None:
        self.inputs = inputs
        self.outputs = outputs

    @classmethod
    def parse(cls, input_shapes: List[List[int]], subscripts: str) -> Self:
        subscripts = subscripts.replace(' ', '')
        # Easier to deal with broadcasting as a single character
        subscripts = subscripts.replace('...', '?')
        # The broadcasting character is automatically sorted to the start
        letters = sorted(subscripts.replace(',', '').replace('->', ''))
        if '->' not in subscripts:
            # Iterate over unique letters so '?' is emitted only once,
            # even when several inputs use broadcasting
            output_letters = [l for l in sorted(set(letters))
                              if l == '?' or letters.count(l) == 1]
            subscripts += '->' + ''.join(output_letters)
        letter_dict = {v: EinsumComp(0) for v in set(letters) if v != '?'}

        inputs_subs, output_subs = subscripts.split('->')
        inputs: List[List[EinsumComp]] = []
        broadcast_comps: List[EinsumComp] = []
        for sub, shape in zip(inputs_subs.split(','), input_shapes):
            inputs.append([])
            for c in sub:
                if c == '?':
                    # Broadcasting works from the last axis to the first and shares these axes with other broadcasts
                    undefined_axes = len(shape) - (len(sub) - 1)
                    for _ in range(undefined_axes - len(broadcast_comps)):
                        broadcast_comps.insert(0, EinsumComp(0))
                    # Slice from an explicit start index: `[-0:]` would select the whole list
                    inputs[-1].extend(
                        broadcast_comps[len(broadcast_comps) - undefined_axes:])
                else:
                    inputs[-1].append(letter_dict[c])

        outputs: List[EinsumComp] = []
        for c in output_subs:
            if c == '?':
                # All broadcasted axes are added in order
                outputs.extend(broadcast_comps)
            else:
                outputs.append(letter_dict[c])

        script = EinsumScript(inputs, outputs)
        for inp, shape in zip(inputs, input_shapes):
            assert len(inp) == len(shape)
            for comp, dim in zip(inp, shape):
                comp.size = dim

        return script

    def split_comp(self, comp: EinsumComp, part_sizes: List[int]) -> None:
        repeats = [EinsumComp(size) for size in part_sizes[1:]]
        comp.size = part_sizes[0]
        for inp in [*self.inputs, self.outputs]:
            for i in range(len(inp)-1, -1, -1):
                if inp[i] == comp:
                    for rep in repeats[::-1]:
                        inp.insert(i+1, rep)

    def remove_ones(self):
        for inp in [*self.inputs, self.outputs]:
            for i in range(len(inp)-1, -1, -1):
                if inp[i].size == 1:
                    inp.pop(i)

    def transform_shapes(self, input_shapes: List[List[int]]) -> List[int]:
        assert len(input_shapes) == len(self.inputs)
        shape_dict = {sub: comp for subs, shape in zip(
            self.inputs, input_shapes) for sub, comp in zip(subs, shape)}
        return [shape_dict[out_sub] for out_sub in self.outputs]

    def simplify(self):
        next_map: BiDict[Union[NullTag, EinsumComp],
                         Union[NullTag, EinsumComp]] = BiDict()

        for comps in [*self.inputs, self.outputs]:
            prev = NullTag()
            for comp in comps:
                if prev in next_map:
                    if next_map[prev] != comp:
                        next_map[NullTag()] = next_map[prev]
                        next_map[NullTag()] = comp
                        next_map[prev] = NullTag()
                elif comp in next_map.values():
                    # Don't need to check if key is already the same as this will be caught by the previous condition
                    key = next_map.inverse[comp]
                    next_map[key] = NullTag()
                    next_map[prev] = NullTag()
                    next_map[NullTag()] = comp
                else:
                    next_map[prev] = comp
                prev = comp
            next_map[prev] = NullTag()

        null_tags = [key for key in next_map if isinstance(key, NullTag)]
        group_pairs: List[Tuple[List[EinsumComp], EinsumComp]] = []
        for tag in null_tags:
            seq: List[EinsumComp] = []
            while not isinstance(next_map[tag], NullTag):
                seq.append(cast(EinsumComp, next_map[tag]))
                tag = next_map[tag]
            if len(seq) > 1:
                group_pairs.append(
                    (seq, EinsumComp(math.prod(comp.size for comp in seq))))

        for comps in [*self.inputs, self.outputs]:
            for group, new_comp in group_pairs:
                while group[0] in comps:
                    i = comps.index(group[0])
                    comps[i] = new_comp
                    for _ in range(len(group) - 1):
                        comps.pop(i + 1)

    def simplified(self) -> Self:
        val = copy.deepcopy(self)
        val.simplify()
        return val

    @staticmethod
    def _get_char(index: int) -> str:
        return chr((ord('a') if index < 26 else (ord('A') - 26)) + index)

    def __str__(self) -> str:
        comps = list(set(comp for inp in self.inputs for comp in inp))

        subs = []
        for inp in self.inputs:
            subs.append(''.join(self._get_char(comps.index(comp))
                                for comp in inp))

        output_str = ''.join(self._get_char(comps.index(comp))
                             for comp in self.outputs)

        return ','.join(subs) + '->' + output_str

    def __add__(self, rhs: Self) -> Self:
        lhs = copy.deepcopy(self)
        rhs = copy.deepcopy(rhs)
        lhs_out_iter = rev_mut_iter(lhs.outputs)
        rhs_in_iter = rev_mut_iter(rhs.inputs[0])

        lhs_out_val = next(lhs_out_iter)
        rhs_in_val = next(rhs_in_iter)

        try:
            while True:
                if lhs_out_val.size == rhs_in_val.size:
                    lhs_out_val = next(lhs_out_iter)
                    rhs_in_val = next(rhs_in_iter)
                elif lhs_out_val.size > rhs_in_val.size:
                    lhs.split_comp(lhs_out_val, [
                        lhs_out_val.size // rhs_in_val.size, rhs_in_val.size])
                    rhs_in_val = next(rhs_in_iter)
                else:
                    rhs.split_comp(rhs_in_val, [
                        rhs_in_val.size // lhs_out_val.size, lhs_out_val.size])
                    lhs_out_val = next(lhs_out_iter)
        except StopIteration:
            pass

        assert len(lhs.outputs) == len(rhs.inputs[0])
        assert all(x.size == y.size for x, y in zip(
            lhs.outputs, rhs.inputs[0]))

        for i, x in enumerate(rhs.inputs[0]):
            val = lhs.outputs[i]
            lhs.outputs[i] = x
            for inp in lhs.inputs:
                if val in inp:
                    for j, y in enumerate(inp):
                        if y == val:
                            inp[j] = x

        return EinsumScript(lhs.inputs + rhs.inputs[1:], rhs.outputs)


T = TypeVar('T')


def rev_mut_iter(data: List[T]) -> Generator[T, None, None]:
    for i in range(len(data)-1, -1, -1):
        yield data[i]
