Skip to content

Commit

Permalink
Replaced simplify logic
Browse files Browse the repository at this point in the history
Added more debug code
  • Loading branch information
davystrong committed May 25, 2023
1 parent 5c31744 commit c24e814
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 84 deletions.
37 changes: 0 additions & 37 deletions src/einsum_pipe/bidict.py

This file was deleted.

1 change: 1 addition & 0 deletions src/einsum_pipe/einsum_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _collapse(scripts: List[EinsumScript]) -> List[EinsumScript]:
splits: List[EinsumScript] = [scripts[0]]
output_sizes: List[int] = []
for i, script in enumerate(scripts[1:], 1):
script.simplify()
try:
output_sizes.append(math.prod(splits[-1].output_shape))
splits[-1] += script
Expand Down
94 changes: 47 additions & 47 deletions src/einsum_pipe/einsum_script.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
from __future__ import annotations
import copy
import math
from typing import Generator, List, Optional, Tuple, TypeVar, Union, cast
from .bidict import _BiDict
from typing import Generator, List, NamedTuple, Optional, Tuple, TypeVar


class EinsumComp:
def __init__(self, size: int) -> None:
def __init__(self, size: int, parsed_from: Optional[str] = None) -> None:
self.size = size


class NullTag:
pass
self._parsed_from = parsed_from


class IncompatibleShapeError(Exception):
Expand All @@ -36,16 +32,16 @@ def parse(cls, input_shapes: List[List[int]], subscripts: str) -> EinsumScript:
parsed_subscripts = subscripts
subscripts = subscripts.replace(' ', '')
# Easier to deal with broadcasting as a single character
subscripts = subscripts.replace('...', '?')
subscripts = subscripts.replace('...', '.')
# The broadcasting character is automatically sorted to the start
letters = sorted(subscripts.replace(',', '').replace('->', ''))
if '->' not in subscripts:
output_letters = [l for l in letters if l ==
'?' or letters.count(l) == 1]
if (bc_count := output_letters.count('?')) > 1:
'.' or letters.count(l) == 1]
if (bc_count := output_letters.count('.')) > 1:
output_letters = output_letters[bc_count - 1:]
subscripts += '->' + ''.join(output_letters)
letter_dict = {v: EinsumComp(0) for v in set(letters) if v != '?'}
letter_dict = {v: EinsumComp(0, v) for v in set(letters) if v != '.'}

inputs_subs, output_subs = subscripts.split('->')
inputs: List[List[EinsumComp]] = []
Expand All @@ -55,11 +51,11 @@ def parse(cls, input_shapes: List[List[int]], subscripts: str) -> EinsumScript:
for sub, shape in zip(inputs_subs.split(','), input_shapes):
inputs.append([])
for c in sub:
if c == '?':
if c == '.':
# Broadcasting works from the last axis to the first and shares these axes with other broadcasts
undefined_axes = len(shape) - (len(sub) - 1)
for _ in range(undefined_axes - len(broadcast_comps)):
broadcast_comps.insert(0, EinsumComp(0))
broadcast_comps.insert(0, EinsumComp(0, '...'))
if undefined_axes > 0:
inputs[-1].extend(broadcast_comps[-undefined_axes:])
else:
Expand All @@ -69,7 +65,7 @@ def parse(cls, input_shapes: List[List[int]], subscripts: str) -> EinsumScript:

outputs: List[EinsumComp] = []
for c in output_subs:
if c == '?':
if c == '.':
# All broadcasted axes are added in order
outputs.extend(broadcast_comps)
else:
Expand All @@ -86,7 +82,8 @@ def parse(cls, input_shapes: List[List[int]], subscripts: str) -> EinsumScript:
return script

def split_comp(self, comp: EinsumComp, part_sizes: List[int]) -> None:
repeats = [EinsumComp(size) for size in part_sizes[1:]]
repeats = [EinsumComp(size, comp._parsed_from)
for size in part_sizes[1:]]
comp.size = part_sizes[0]
for inp in [*self.inputs, self.outputs]:
for i in range(len(inp)-1, -1, -1):
Expand All @@ -110,39 +107,34 @@ def output_shape(self) -> Tuple[int]:

def simplify(self):
# Get sequences (repeated or not) in which each element is unique to the sequence
# Basically, run through it as something like a linked list
next_map: _BiDict[Union[NullTag, EinsumComp],
Union[NullTag, EinsumComp]] = _BiDict()

# Might be more efficient with some sort of linked list, but doesn't matter
seqs: List[Optional[EinsumComp]] = []
for comps in [*self.inputs, self.outputs]:
prev = NullTag()
for comp in comps:
if prev in next_map:
if next_map[prev] != comp:
next_map[NullTag()] = next_map[prev]
next_map[NullTag()] = comp
next_map[prev] = NullTag()
elif comp in next_map.values():
# Don't need to check if key is already the same as this will be caught by the previous condition
key = next_map.inverse[comp]
next_map[key] = NullTag()
next_map[prev] = NullTag()
next_map[NullTag()] = comp
for prev_comp, comp, next_comp in zip([None, *comps[:-1]], comps, [*comps[1:], None]):
if comp not in seqs:
seqs.append(comp)
else:
next_map[prev] = comp
prev = comp
next_map[prev] = NullTag()

null_tags = [key for key in next_map if isinstance(key, NullTag)]
group_pairs: List[Tuple[List[EinsumComp], EinsumComp]] = []
for tag in null_tags:
seq: List[EinsumComp] = []
while not isinstance(next_map[tag], NullTag):
seq.append(cast(EinsumComp, next_map[tag]))
tag = next_map[tag]
if len(seq) > 1:
group_pairs.append(
(seq, EinsumComp(math.prod(comp.size for comp in seq))))
seqs.append(None)
i = seqs.index(comp)
if i == len(seqs) - 1 or seqs[i + 1] != next_comp:
seqs.insert(i + 1, None)
if i == 0 or seqs[i - 1] != prev_comp:
seqs.insert(i, None)
seqs.append(None)

groups: List[List[EinsumComp]] = [[]]
for comp in seqs:
if comp is None:
groups.append([])
else:
groups[-1].append(comp)

group_pairs = [(group, EinsumComp(math.prod(comp.size for comp in group), ''.join(
comp._parsed_from or '.' for comp in group))) for group in groups if len(group) > 1]

# To check if the sizes before reshaping are the same as the sizes after
sizes_before = [math.prod(shape) for shape in [
*self.input_shapes, self.output_shape]]

# Replace sequences of comps with their respective group comp
for comps in [*self.inputs, self.outputs]:
Expand All @@ -153,6 +145,11 @@ def simplify(self):
for _ in range(len(group) - 1):
comps.pop(i + 1)

sizes_after = [math.prod(shape) for shape in [
*self.input_shapes, self.output_shape]]
assert all(before == after for before,
after in zip(sizes_before, sizes_after)), 'This is a bug. Please submit a bug report!'

def simplified(self) -> EinsumScript:
val = copy.deepcopy(self)
val.simplify()
Expand Down Expand Up @@ -190,7 +187,10 @@ def __repr__(self) -> str:
return f'"{self}" (parsed from "{self._parsed_script}")'

def __str__(self) -> str:
comps = list(set(comp for inp in self.inputs for comp in inp))
# This is equivalent to using a set except that it preserves order (at least since Python 3.7)
# This isn't required but produces more natural string outputs
comps: List[EinsumComp] = list(dict.fromkeys(
comp for inp in self.inputs for comp in inp).keys())

subs = []
for inp in self.inputs:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_einsum_pipe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from einsum_pipe import einsum_pipe, compile_einsum_args
from einsum_pipe.einsum_script import EinsumScript


def einsum_pipe_simple(*args):
Expand Down Expand Up @@ -124,6 +125,13 @@ def test_implicit_broadcast():
assert np.allclose(einsum_pipe(*args), einsum_pipe_simple(*args))


def test_basic_simplify():
    """Check that simplifying 'abcd->acd' prints like a plain 'abc->ac' script.

    The first script is parsed with input shape [10, 20, 5, 6] and simplified
    in place; the reference script is parsed with shape [10, 20, 30]
    (30 == 5 * 6) and left untouched. The test passes when both scripts have
    the same string form.
    """
    simplified_script = EinsumScript.parse([[10, 20, 5, 6]], 'abcd->acd')
    simplified_script.simplify()
    reference_script = EinsumScript.parse([[10, 20, 30]], 'abc->ac')

    # Compare via str() only, as the original test does.
    simplified_text = str(simplified_script)
    reference_text = str(reference_script)
    assert simplified_text == reference_text


def test_simplify():
# Make a discontiguous array
A = np.random.rand(3, 3, 3, 9).transpose((0, 2, 1, 3))
Expand Down

0 comments on commit c24e814

Please sign in to comment.