-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
195e301
commit 0eb7ead
Showing
7 changed files
with
366 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2023 David R. Armstrong | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Einsum Pipe | ||
|
||
A Python package to compile multiple Numpy [einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) operations into one. | ||
|
||
## Example | ||
|
||
Given two arrays: | ||
```python | ||
A = np.random.rand(32, 32, 10, 5) | ||
B = np.random.rand(32, 32, 10, 5) | ||
``` | ||
|
||
We frequently need to run multiple reshape/transpose/products/trace/etc., such as: | ||
```python | ||
C = np.einsum('ij...,kl...->ikjl...', A, B) | ||
D = C.reshape([2, ]*20 + [10, 5]) | ||
E = D.transpose([2, 3, 4, 5, 6, 7, 8, 9, 12, 13, 14, | ||
15, 16, 17, 18, 19, 0, 1, 10, 11, 20, 21]) | ||
F = E.reshape([256, 256, 4, 4, 10, 5]) | ||
X = np.trace(F) | ||
``` | ||
|
||
This obviously results in multiple intermediate arrays, some of which can be large. Instead of doing this, it is possible to combine multiple `np.einsum` operations into one. By carefully modifying the input shape, it is even possible to do this in cases in which the intermediate data is reshaped during the process, provided the shapes are all [compatible](#shape-compatibility). The previous example can instead be performed in a single `np.einsum` step: | ||
``` | ||
X = einsum_pipe( | ||
'ik...,jl...->ijkl...', | ||
[2, ]*20 + [10, 5], | ||
'abcde fghij klmno pqrst...->cde fghij mno pqrst ab kl...', | ||
[256, 256, 4, 4, 10, 5], | ||
'ii...', | ||
A, B | ||
) | ||
``` | ||
|
||
Internally, this calculates a compatible input shape, `(4, 8, 4, 8, 50)` and `(32, 32, 50)`, and a combined `np.einsum` set of subscripts, `"ebdbc,aac->edc"`. `A` and `B` are reshaped (which is generally essentially free), the single `np.einsum` operation is run, and the output is reshaped back to the expected output shape. | ||
|
||
## Syntax | ||
|
||
The syntax is based on Numpy's `einsum`, with the addition of allowing multiple subscripts and defining the shapes of the intermediate arrays. The input arrays can be put at the end, as shown, or next to the subscript definitions. In this example, only two arrays are used at start of the pipe, however you can add more arrays at later stages. The output of the previous step is always considered the first input of the subsequent step. | ||
|
||
## Shape Compatibility | ||
|
||
Shapes are compatible if each dimension is the product of some subsequence of a matching shape (of the previous output). For example, `(32, 32)` and `(4, 256)` are compatible, since both can be built from the shape `(4, 8, 4, 8)`: `(4*8, 4*8)` and `(4, 8*4*8)`. On the other hand, `(2, 3)` and `(3, 2)` aren't directly compatible since they don't share divisors. | ||
|
||
Note that transposition of axes also causes the transposition of the compatible shape, so while `[(3, 2), 'ij->ij', (2, 3)]` isn't valid, `[(3, 2), 'ij->ji', (2, 3)]` is. | ||
|
||
I plan to implement a best effort fallback which would reduce a sequence of operations to as few operations as possible, depending on incompatible shapes. | ||
|
||
## Numpy Operations |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
[build-system] | ||
requires = ["hatchling"] | ||
build-backend = "hatchling.build" | ||
|
||
[project] | ||
name = "einsum_pipe" | ||
version = "0.0.1" | ||
authors = [ | ||
{ name="David R. Armstrong" }, | ||
] | ||
description = "A Python package to compile multiple Numpy einsum operations into one" | ||
readme = "README.md" | ||
requires-python = ">=3.7" | ||
classifiers = [ | ||
"Programming Language :: Python :: 3", | ||
"License :: OSI Approved :: MIT License", | ||
"Operating System :: OS Independent", | ||
] | ||
dependencies = [ | ||
"numpy" | ||
] | ||
|
||
[project.urls] | ||
"Homepage" = "https://github.com/davystrong/Einsum-Pipe" | ||
"Bug Tracker" = "https://github.com/davystrong/Einsum-Pipe/issues" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .einsum_script import EinsumScript, EinsumComp | ||
from .einsum_pipe import einsum_pipe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Dict, Generic, Iterator, TypeVar | ||
from collections.abc import MutableMapping | ||
|
||
|
||
K = TypeVar('K') | ||
V = TypeVar('V') | ||
|
||
|
||
class BiDict(MutableMapping, Generic[K, V]): | ||
"""A custom dictionary that keeps track of the inverse mapping from values to keys. Values must be unique too | ||
""" | ||
|
||
def __init__(self): | ||
self.store: Dict[K, V] = {} | ||
self.inverse: Dict[V, K] = {} | ||
|
||
def __getitem__(self, key: K) -> V: | ||
return self.store[key] | ||
|
||
def __setitem__(self, key: K, value: V) -> None: | ||
if key in self: | ||
del self.inverse[self[key]] | ||
self.inverse[value] = key | ||
self.store[key] = value | ||
|
||
def __delitem__(self, key: K) -> None: | ||
del self.inverse[self[key]] | ||
del self.store[key] | ||
|
||
def values(self): | ||
return self.inverse.keys() | ||
|
||
def __iter__(self) -> Iterator[K]: | ||
return iter(self.store) | ||
|
||
def __len__(self): | ||
return len(self.store) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from functools import reduce | ||
from typing import List | ||
|
||
import numpy as np | ||
from .einsum_script import EinsumScript | ||
|
||
|
||
def einsum_pipe(*args): | ||
subs = [arg for arg in args if isinstance(arg, (str, list, tuple))] | ||
ops = [arg for arg in args if not isinstance(arg, (str, list, tuple))] | ||
ops_index = 0 | ||
scripts: List[EinsumScript] = [] | ||
|
||
while len(subs) > 0: | ||
input_shapes = [] | ||
sub = subs.pop(0) | ||
if not isinstance(sub, str): | ||
input_shapes.append(sub) | ||
sub = subs.pop(0) | ||
assert isinstance(sub, str) | ||
|
||
args = sub.count(',') + 1 - len(input_shapes) | ||
|
||
input_shapes.extend(tuple(x.shape) | ||
for x in ops[ops_index:ops_index+args]) | ||
ops_index += args | ||
|
||
x = EinsumScript.parse(input_shapes, sub) | ||
scripts.append(x) | ||
|
||
output_script = reduce(lambda x, y: x+y, scripts) | ||
output_script.simplify() | ||
reshaped_ops = [np.reshape(op, [comp.size for comp in inp]) | ||
for op, inp in zip(ops, output_script.inputs)] | ||
raw_output: np.ndarray = np.einsum(str(output_script), *reshaped_ops) | ||
return raw_output.reshape([comp.size for comp in scripts[-1].outputs]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
import copy | ||
import math | ||
from typing import Generator, List, Self, Tuple, TypeVar, Union, cast | ||
from .bidict import BiDict | ||
|
||
|
||
class EinsumComp: | ||
def __init__(self, size: int) -> None: | ||
self.size = size | ||
|
||
|
||
class NullTag: | ||
pass | ||
|
||
|
||
class EinsumScript: | ||
def __init__(self, inputs: List[List[EinsumComp]], outputs: List[EinsumComp]) -> None: | ||
self.inputs = inputs | ||
self.outputs = outputs | ||
|
||
@classmethod | ||
def parse(cls, input_shapes: List[List[int]], subscripts: str) -> Self: | ||
subscripts = subscripts.replace(' ', '') | ||
# Easier to deal with broadcasting as a single character | ||
subscripts = subscripts.replace('...', '?') | ||
# The broadcasting character is automatically sorted to the start | ||
letters = sorted(subscripts.replace(',', '').replace('->', '')) | ||
if '->' not in subscripts: | ||
output_letters = [l for l in letters if l == | ||
'?' or letters.count(l) == 1] | ||
subscripts += '->' + ''.join(output_letters) | ||
letter_dict = {v: EinsumComp(0) for v in set(letters) if v != '?'} | ||
|
||
inputs_subs, output_subs = subscripts.split('->') | ||
inputs: List[List[EinsumComp]] = [] | ||
broadcast_comps: List[EinsumComp] = [] | ||
for sub, shape in zip(inputs_subs.split(','), input_shapes): | ||
inputs.append([]) | ||
for c in sub: | ||
if c == '?': | ||
# Broadcasting works from the last axis to the first and shares these axes with other broadcasts | ||
undefined_axes = len(shape) - (len(sub) - 1) | ||
for _ in range(undefined_axes - len(broadcast_comps)): | ||
broadcast_comps.insert(0, EinsumComp(0)) | ||
inputs[-1].extend(broadcast_comps[-undefined_axes:]) | ||
else: | ||
inputs[-1].append(letter_dict[c]) | ||
|
||
outputs: List[EinsumComp] = [] | ||
for c in output_subs: | ||
if c == '?': | ||
# All broadcasted axes are added in order | ||
outputs.extend(broadcast_comps) | ||
else: | ||
outputs.append(letter_dict[c]) | ||
|
||
script = EinsumScript(inputs, outputs) | ||
for inp, shape in zip(inputs, input_shapes): | ||
assert len(inp) == len(shape) | ||
for comp, dim in zip(inp, shape): | ||
comp.size = dim | ||
|
||
return script | ||
|
||
def split_comp(self, comp: EinsumComp, part_sizes: List[int]) -> None: | ||
repeats = [EinsumComp(size) for size in part_sizes[1:]] | ||
comp.size = part_sizes[0] | ||
for inp in [*self.inputs, self.outputs]: | ||
for i in range(len(inp)-1, -1, -1): | ||
if inp[i] == comp: | ||
for rep in repeats[::-1]: | ||
inp.insert(i+1, rep) | ||
|
||
def remove_ones(self): | ||
for inp in [*self.inputs, self.outputs]: | ||
for i in range(len(inp)-1, -1, -1): | ||
if inp[i].size == 1: | ||
inp.pop(i) | ||
|
||
def transform_shapes(self, input_shapes: List[List[int]]) -> List[int]: | ||
assert len(input_shapes) == len(self.inputs) | ||
shape_dict = {sub: comp for subs, shape in zip( | ||
self.inputs, input_shapes) for sub, comp in zip(subs, shape)} | ||
return [shape_dict[out_sub] for out_sub in self.outputs] | ||
|
||
def simplify(self): | ||
next_map: BiDict[Union[NullTag, EinsumComp], | ||
Union[NullTag, EinsumComp]] = BiDict() | ||
|
||
for comps in [*self.inputs, self.outputs]: | ||
prev = NullTag() | ||
for comp in comps: | ||
if prev in next_map: | ||
if next_map[prev] != comp: | ||
next_map[NullTag()] = next_map[prev] | ||
next_map[NullTag()] = comp | ||
next_map[prev] = NullTag() | ||
elif comp in next_map.values(): | ||
# Don't need to check if key is already the same as this will be caught by the previous condition | ||
key = next_map.inverse[comp] | ||
next_map[key] = NullTag() | ||
next_map[prev] = NullTag() | ||
next_map[NullTag()] = comp | ||
else: | ||
next_map[prev] = comp | ||
prev = comp | ||
next_map[prev] = NullTag() | ||
|
||
null_tags = [key for key in next_map if isinstance(key, NullTag)] | ||
group_pairs: List[Tuple[List[EinsumComp], EinsumComp]] = [] | ||
for tag in null_tags: | ||
seq: List[EinsumComp] = [] | ||
while not isinstance(next_map[tag], NullTag): | ||
seq.append(cast(EinsumComp, next_map[tag])) | ||
tag = next_map[tag] | ||
if len(seq) > 1: | ||
group_pairs.append( | ||
(seq, EinsumComp(math.prod(comp.size for comp in seq)))) | ||
|
||
for comps in [*self.inputs, self.outputs]: | ||
for group, new_comp in group_pairs: | ||
while group[0] in comps: | ||
i = comps.index(group[0]) | ||
comps[i] = new_comp | ||
for _ in range(len(group) - 1): | ||
comps.pop(i + 1) | ||
|
||
def simplified(self) -> Self: | ||
val = copy.deepcopy(self) | ||
val.simplify() | ||
return val | ||
|
||
@staticmethod | ||
def _get_char(index: int) -> str: | ||
return chr((ord('a') if index < 26 else (ord('A') - 26)) + index) | ||
|
||
def __str__(self) -> str: | ||
comps = list(set(comp for inp in self.inputs for comp in inp)) | ||
|
||
subs = [] | ||
for inp in self.inputs: | ||
subs.append(''.join(self._get_char(comps.index(comp)) | ||
for comp in inp)) | ||
|
||
output_str = ''.join(self._get_char(comps.index(comp)) | ||
for comp in self.outputs) | ||
|
||
return ','.join(subs) + '->' + output_str | ||
|
||
def __add__(self, rhs: Self) -> Self: | ||
lhs = copy.deepcopy(self) | ||
rhs = copy.deepcopy(rhs) | ||
lhs_out_iter = rev_mut_iter(lhs.outputs) | ||
rhs_in_iter = rev_mut_iter(rhs.inputs[0]) | ||
|
||
lhs_out_val = next(lhs_out_iter) | ||
rhs_in_val = next(rhs_in_iter) | ||
|
||
try: | ||
while True: | ||
if lhs_out_val.size == rhs_in_val.size: | ||
lhs_out_val = next(lhs_out_iter) | ||
rhs_in_val = next(rhs_in_iter) | ||
elif lhs_out_val.size > rhs_in_val.size: | ||
lhs.split_comp(lhs_out_val, [ | ||
lhs_out_val.size // rhs_in_val.size, rhs_in_val.size]) | ||
rhs_in_val = next(rhs_in_iter) | ||
else: | ||
rhs.split_comp(rhs_in_val, [ | ||
rhs_in_val.size // lhs_out_val.size, lhs_out_val.size]) | ||
lhs_out_val = next(lhs_out_iter) | ||
except StopIteration: | ||
pass | ||
|
||
assert len(lhs.outputs) == len(rhs.inputs[0]) | ||
assert all(x.size == y.size for x, y in zip( | ||
lhs.outputs, rhs.inputs[0])) | ||
|
||
for i, x in enumerate(rhs.inputs[0]): | ||
val = lhs.outputs[i] | ||
lhs.outputs[i] = x | ||
for inp in lhs.inputs: | ||
if val in inp: | ||
for j, y in enumerate(inp): | ||
if y == val: | ||
inp[j] = x | ||
|
||
return EinsumScript(lhs.inputs + rhs.inputs[1:], rhs.outputs) | ||
|
||
|
||
T = TypeVar('T') | ||
|
||
|
||
def rev_mut_iter(data: List[T]) -> Generator[T, None, None]: | ||
for i in range(len(data)-1, -1, -1): | ||
yield data[i] |