Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Finalizes typing and updates documentation #232

Merged
merged 12 commits into from
May 27, 2024
Merged
14 changes: 9 additions & 5 deletions docs/api_reference.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
---
toc_depth: 1
---

# API Documentation

### `opt_einsum.contract`

::: opt_einsum.contract
::: opt_einsum.contract.contract
<!-- :docstring: -->

### `opt_einsum.contract_path`

::: opt_einsum.contract_path
::: opt_einsum.contract.contract_path
<!-- :docstring: -->

### `opt_einsum.contract_expression`

::: opt_einsum.contract_expression
::: opt_einsum.contract.contract_expression
<!-- :docstring:
:members: -->

Expand All @@ -29,12 +33,12 @@

### `opt_einsum.get_symbol`

::: opt_einsum.get_symbol
::: opt_einsum.parser.get_symbol
<!-- :docstring: -->

### `opt_einsum.shared_intermediates`

::: opt_einsum.shared_intermediates
::: opt_einsum.sharing.shared_intermediates
<!-- :docstring: -->

### `opt_einsum.paths.optimal`
Expand Down
6 changes: 3 additions & 3 deletions docs/getting_started/backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ The following is a brief overview of libraries which have been tested with
The automatic backend detection will be detected based on the first supplied
array (default), this can be overridden by specifying the correct `backend`
argument for the type of arrays supplied when calling
[`opt_einsum.contract`](../api_reference.md##opt_einsumcontract). For example, if you had a library installed
[`opt_einsum.contract`](../api_reference.md#opt_einsum.contract.contract). For example, if you had a library installed
called `'foo'` which provided an `numpy.ndarray` like object with a
`.shape` attribute as well as `foo.tensordot` and `foo.transpose` then
you could contract them with something like:
Expand Down Expand Up @@ -189,7 +189,7 @@ Currently `opt_einsum` can handle this automatically for:

all of which offer GPU support. Since `tensorflow` and `theano` both require
compiling the expression, this functionality is encapsulated in generating a
[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsumcontractcontractexpression) using
[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsum.contract.contract_expression) using
[`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression), which can then be called using numpy
arrays whilst specifying `backend='tensorflow'` etc.
Additionally, if arrays are marked as `constant`
Expand Down Expand Up @@ -259,7 +259,7 @@ tf.enable_eager_execution()

After which `opt_einsum` will automatically detect eager mode if
`backend='tensorflow'` is supplied to a
[`opt_einsum.ContractExpression`](../api_reference.md###opt_einsumcontractcontractexpression).
[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsum.contract.contract_expression).


### Pytorch & Cupy
Expand Down
4 changes: 2 additions & 2 deletions docs/paths/custom_paths.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Custom Path Optimizers

If you want to implement or just experiment with custom contaction paths then
you can easily by subclassing the [`opt_einsum.paths.PathOptimizer`](../api_reference.md#opt_einsumpathspathoptimizer)
you can easily by subclassing the [`opt_einsum.paths.PathOptimizer`](../api_reference.md#opt_einsum.paths.PathOptimizer)
object. For example, imagine we want to test the path that just blindly
contracts the first pair of tensors again and again. We would implement this
as:
Expand Down Expand Up @@ -49,7 +49,7 @@ machinery of the random-greedy approach. Namely:
- Parallelization using a pool-executor

This is done by subclassing the
[`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsumpathsrandomoptimizer)
[`opt_einsum.paths.RandomOptimizer`](../api_reference.md#opt_einsum.path_random.RandomOptimizer)
object and implementing a
`setup` method. Here's an example where we just randomly select any path
(again, although we get a considerable speedup over `einsum` this is
Expand Down
22 changes: 0 additions & 22 deletions docs/reference/api.rst

This file was deleted.

45 changes: 45 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@ repo_url: https://github.com/dgasmith/opt_einsum
repo_name: dgasmith/opt_einsum
theme:
name: material
features:
- navigation.instant
palette:

# Palette toggle for automatic mode
- media: "(prefers-color-scheme)"
toggle:
icon: material/brightness-auto
name: Switch to light mode

# Palette toggle for light mode
- media: "(prefers-color-scheme: light)"
scheme: default
toggle:
icon: material/brightness-7
name: Switch to dark mode

# Palette toggle for dark mode
- media: "(prefers-color-scheme: dark)"
scheme: slate
toggle:
icon: material/brightness-4
name: Switch to system preference


plugins:
- search
Expand All @@ -14,6 +38,25 @@ plugins:
# paths: [opt_einsum]
options:
docstring_style: google
docstring_options:
ignore_init_summary: true
docstring_section_style: list
filters: ["!^_"]
heading_level: 1
inherited_members: true
merge_init_into_class: true
parameter_headings: true
preload_modules: [mkdocstrings]
separate_signature: true
show_root_heading: true
show_root_full_path: false
show_signature_annotations: true
show_source: false
show_symbol_type_heading: true
show_symbol_type_toc: true
signature_crossrefs: true
summary: true
unwrap_annotated: true

extra_javascript:
- javascript/config.js
Expand All @@ -36,6 +79,8 @@ markdown_extensions:
- pymdownx.extra
- pymdownx.arithmatex:
generic: true
- toc:
toc_depth: 2

nav:
- Overview: index.md
Expand Down
14 changes: 7 additions & 7 deletions opt_einsum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
Main init function for opt_einsum.
"""

from . import blas, helpers, path_random, paths
from .contract import contract, contract_expression, contract_path
from .parser import get_symbol
from .path_random import RandomGreedy
from .paths import BranchBound, DynamicProgramming
from .sharing import shared_intermediates
from opt_einsum import blas, helpers, path_random, paths
from opt_einsum.contract import contract, contract_expression, contract_path
from opt_einsum.parser import get_symbol
from opt_einsum.path_random import RandomGreedy
from opt_einsum.paths import BranchBound, DynamicProgramming
from opt_einsum.sharing import shared_intermediates

# Handle versioneer
from ._version import get_versions # isort:skip
from opt_einsum._version import get_versions # isort:skip

versions = get_versions()
__version__ = versions["version"]
Expand Down
4 changes: 3 additions & 1 deletion opt_einsum/backends/object_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import numpy as np

from opt_einsum.typing import ArrayType

def object_einsum(eq, *arrays):

def object_einsum(eq: str, *arrays: ArrayType) -> ArrayType:
"""A ``einsum`` implementation for ``numpy`` arrays with object dtype.
The loop is performed in python, meaning the objects themselves need
only to implement ``__mul__`` and ``__add__`` for the contraction to be
Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def transpose(a, axes):
return a.permute(*axes)


def einsum(equation, *operands):
def einsum(equation, *operands, **kwargs):
"""Variadic version of torch.einsum to match numpy api."""
# rename symbols to support PyTorch 0.4.1 and earlier,
# which allow only symbols a-z.
Expand Down
Loading
Loading