Format and sort imports, run ruff rules
agriyakhetarpal committed Aug 26, 2024
1 parent 68e3dda commit 0f7bce4
Showing 96 changed files with 288 additions and 277 deletions.
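Every hunk below follows the same two mechanical transformations: imports are regrouped and alphabetized the way ruff's isort-compatible lint rules (the "I" group) arrange them, and a few needlessly wrapped calls are joined by the formatter. The commit does not show its ruff configuration, so as a hedged sketch, a run along the lines of `ruff check --select I --fix .` followed by `ruff format .` would produce diffs like these. The resulting import layout, illustrated with names taken from the diffs below (the module itself is hypothetical, and the relative import in the last block only resolves inside a package):

# A sketch of the import layout that isort-style sorting enforces;
# an illustrative module, not a file from this commit.

import types  # 1. standard library, alphabetized
import warnings

import numpy as _np  # 2. third-party packages, in their own block

import autograd.builtins as builtins  # 3. first-party absolute imports
from autograd.extend import notrace_primitive, primitive

# 4. relative imports come last; names inside each "from" list are
#    sorted alphabetically, with uppercase names before lowercase.
from .util import subvals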
29 changes: 15 additions & 14 deletions autograd/__init__.py
@@ -1,23 +1,24 @@
+from autograd.core import primitive_with_deprecation_warnings as primitive
+
+from .builtins import dict, isinstance, list, tuple, type
 from .differential_operators import (
-    make_vjp,
-    grad,
-    multigrad_dict,
+    checkpoint,
+    deriv,
     elementwise_grad,
-    value_and_grad,
+    grad,
     grad_and_aux,
+    grad_named,
+    hessian,
     hessian_tensor_product,
     hessian_vector_product,
-    hessian,
+    holomorphic_grad,
     jacobian,
-    tensor_jacobian_product,
-    vector_jacobian_product,
-    grad_named,
-    checkpoint,
+    make_ggnvp,
     make_hvp,
     make_jvp,
-    make_ggnvp,
-    deriv,
-    holomorphic_grad,
+    make_vjp,
+    multigrad_dict,
+    tensor_jacobian_product,
+    value_and_grad,
+    vector_jacobian_product,
 )
-from .builtins import isinstance, type, tuple, list, dict
-from autograd.core import primitive_with_deprecation_warnings as primitive
14 changes: 7 additions & 7 deletions autograd/builtins.py
@@ -1,16 +1,16 @@
-from .util import subvals
 from .extend import (
     Box,
-    primitive,
-    notrace_primitive,
-    VSpace,
-    vspace,
     SparseObject,
-    defvjp,
-    defvjp_argnum,
+    VSpace,
     defjvp,
     defjvp_argnum,
+    defvjp,
+    defvjp_argnum,
+    notrace_primitive,
+    primitive,
+    vspace,
 )
+from .util import subvals
 
 isinstance_ = isinstance
 isinstance = notrace_primitive(isinstance)
9 changes: 4 additions & 5 deletions autograd/core.py
@@ -1,6 +1,7 @@
-from itertools import count
 from functools import reduce
-from .tracer import trace, primitive, toposort, Node, Box, isbox, getval
+from itertools import count
+
+from .tracer import Box, Node, getval, isbox, primitive, toposort, trace
 from .util import func, subval
 
 # -------------------- reverse mode --------------------
@@ -40,9 +41,7 @@ def __init__(self, value, fun, args, kwargs, parent_argnums, parents):
             vjpmaker = primitive_vjps[fun]
         except KeyError:
             fun_name = getattr(fun, "__name__", fun)
-            raise NotImplementedError(
-                f"VJP of {fun_name} wrt argnums {parent_argnums} not defined"
-            )
+            raise NotImplementedError(f"VJP of {fun_name} wrt argnums {parent_argnums} not defined")
         self.vjp = vjpmaker(parent_argnums, value, args, kwargs)
 
     def initialize_root(self):
12 changes: 6 additions & 6 deletions autograd/differential_operators.py
@@ -1,6 +1,5 @@
"""Convenience functions built on top of `make_vjp`."""

from functools import partial
from collections import OrderedDict

try:
@@ -9,13 +8,14 @@
     from inspect import getargspec as _getargspec  # Python 2
 import warnings
 
-from .wrap_util import unary_to_nary
-from .builtins import tuple as atuple
-from .core import make_vjp as _make_vjp, make_jvp as _make_jvp
-from .extend import primitive, defvjp_argnum, vspace
-
 import autograd.numpy as np
 
+from .builtins import tuple as atuple
+from .core import make_jvp as _make_jvp
+from .core import make_vjp as _make_vjp
+from .extend import defvjp_argnum, primitive, vspace
+from .wrap_util import unary_to_nary
+
 make_vjp = unary_to_nary(_make_vjp)
 make_jvp = unary_to_nary(_make_jvp)
 
20 changes: 10 additions & 10 deletions autograd/extend.py
@@ -1,16 +1,16 @@
 # Exposes API for extending autograd
-from .tracer import Box, primitive, register_notrace, notrace_primitive
 from .core import (
+    JVPNode,
     SparseObject,
-    VSpace,
-    vspace,
     VJPNode,
-    JVPNode,
-    defvjp_argnums,
-    defvjp_argnum,
-    defvjp,
-    defjvp_argnums,
-    defjvp_argnum,
-    defjvp,
+    VSpace,
     def_linear,
+    defjvp,
+    defjvp_argnum,
+    defjvp_argnums,
+    defvjp,
+    defvjp_argnum,
+    defvjp_argnums,
+    vspace,
 )
+from .tracer import Box, notrace_primitive, primitive, register_notrace
2 changes: 1 addition & 1 deletion autograd/misc/__init__.py
@@ -1,2 +1,2 @@
-from .tracers import const_graph
 from .flatten import flatten
+from .tracers import const_graph
4 changes: 2 additions & 2 deletions autograd/misc/fixed_points.py
@@ -1,6 +1,6 @@
-from autograd.extend import primitive, defvjp, vspace
-from autograd.builtins import tuple
 from autograd import make_vjp
+from autograd.builtins import tuple
+from autograd.extend import defvjp, primitive, vspace
 
 
 @primitive
2 changes: 1 addition & 1 deletion autograd/misc/flatten.py
@@ -3,9 +3,9 @@
 arrays. The main purpose is to make examples and optimizers simpler.
 """
 
+import autograd.numpy as np
 from autograd import make_vjp
 from autograd.builtins import type
-import autograd.numpy as np
 
 
 def flatten(value):
1 change: 0 additions & 1 deletion autograd/misc/optimizers.py
@@ -7,7 +7,6 @@
 These routines can optimize functions whose inputs are structured
 objects, such as dicts of numpy arrays."""
 
-
 import autograd.numpy as np
 from autograd.misc import flatten
 from autograd.wrap_util import wraps
7 changes: 4 additions & 3 deletions autograd/misc/tracers.py
@@ -1,8 +1,9 @@
+from functools import partial
 from itertools import repeat
-from autograd.wrap_util import wraps
+
+from autograd.tracer import Node, trace
 from autograd.util import subvals, toposort
-from autograd.tracer import trace, Node
-from functools import partial
+from autograd.wrap_util import wraps
 
 
 class ConstGraphNode(Node):
8 changes: 1 addition & 7 deletions autograd/numpy/__init__.py
@@ -1,8 +1,2 @@
+from . import fft, linalg, numpy_boxes, numpy_jvps, numpy_vjps, numpy_vspaces, random
 from .numpy_wrapper import *
-from . import numpy_boxes
-from . import numpy_vspaces
-from . import numpy_vjps
-from . import numpy_jvps
-from . import linalg
-from . import fft
-from . import random
8 changes: 5 additions & 3 deletions autograd/numpy/fft.py
@@ -1,8 +1,10 @@
 import numpy.fft as ffto
-from .numpy_wrapper import wrap_namespace
-from .numpy_vjps import match_complex
+
+from autograd.extend import defvjp, primitive, vspace
+
 from . import numpy_wrapper as anp
-from autograd.extend import primitive, defvjp, vspace
+from .numpy_vjps import match_complex
+from .numpy_wrapper import wrap_namespace
 
 wrap_namespace(ffto.__dict__, globals())
 
7 changes: 5 additions & 2 deletions autograd/numpy/linalg.py
@@ -1,8 +1,11 @@
 from functools import partial
+
 import numpy.linalg as npla
-from .numpy_wrapper import wrap_namespace
+
+from autograd.extend import defjvp, defvjp
+
 from . import numpy_wrapper as anp
-from autograd.extend import defvjp, defjvp
+from .numpy_wrapper import wrap_namespace
 
 wrap_namespace(npla.__dict__, globals())
 
4 changes: 3 additions & 1 deletion autograd/numpy/numpy_boxes.py
@@ -1,6 +1,8 @@
 import numpy as np
-from autograd.extend import Box, primitive
+
+from autograd.builtins import SequenceBox
+from autograd.extend import Box, primitive
 
 from . import numpy_wrapper as anp
 
 Box.__array_priority__ = 90.0
16 changes: 9 additions & 7 deletions autograd/numpy/numpy_jvps.py
@@ -1,19 +1,21 @@
 import numpy as onp
+
+from autograd.extend import JVPNode, def_linear, defjvp, defjvp_argnum, register_notrace, vspace
+
+from ..util import func
 from . import numpy_wrapper as anp
+from .numpy_boxes import ArrayBox
 from .numpy_vjps import (
-    untake,
     balanced_eq,
-    match_complex,
-    replace_zero,
     dot_adjoint_0,
     dot_adjoint_1,
+    match_complex,
+    nograd_functions,
+    replace_zero,
     tensordot_adjoint_0,
     tensordot_adjoint_1,
-    nograd_functions,
+    untake,
 )
-from autograd.extend import defjvp, defjvp_argnum, def_linear, vspace, JVPNode, register_notrace
-from ..util import func
-from .numpy_boxes import ArrayBox
 
 for fun in nograd_functions:
     register_notrace(JVPNode, fun)
5 changes: 4 additions & 1 deletion autograd/numpy/numpy_vjps.py
@@ -1,9 +1,12 @@
 from functools import partial
+
 import numpy as onp
+
+from autograd.extend import SparseObject, VJPNode, defvjp, defvjp_argnum, primitive, register_notrace, vspace
+
 from ..util import func
 from . import numpy_wrapper as anp
 from .numpy_boxes import ArrayBox
-from autograd.extend import primitive, vspace, defvjp, defvjp_argnum, SparseObject, VJPNode, register_notrace
 
 # ----- Non-differentiable functions -----
 
3 changes: 2 additions & 1 deletion autograd/numpy/numpy_vspaces.py
@@ -1,6 +1,7 @@
 import numpy as np
-from autograd.extend import VSpace
+
 from autograd.builtins import NamedTupleVSpace
+from autograd.extend import VSpace
 
 
 class ArrayVSpace(VSpace):
9 changes: 4 additions & 5 deletions autograd/numpy/numpy_wrapper.py
@@ -1,8 +1,9 @@
 import types
 import warnings
-from autograd.extend import primitive, notrace_primitive
 
 import numpy as _np
+
 import autograd.builtins as builtins
+from autograd.extend import notrace_primitive, primitive
 
 if _np.lib.NumpyVersion(_np.__version__) >= "2.0.0":
     from numpy._core.einsumfunc import _parse_einsum_input
@@ -75,9 +76,7 @@ def wrap_if_boxes_inside(raw_array, slow_op_name=None):
 def wrap_if_boxes_inside(raw_array, slow_op_name=None):
     if raw_array.dtype is _np.dtype("O"):
         if slow_op_name:
-            warnings.warn(
-                "{} is slow for array inputs. " "np.concatenate() is faster.".format(slow_op_name)
-            )
+            warnings.warn("{} is slow for array inputs. " "np.concatenate() is faster.".format(slow_op_name))
         return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape)
     else:
         return raw_array
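A nuance in the joined warnings.warn(...) call above: the message is still written as two adjacent string literals. Python concatenates adjacent literals at compile time, and the formatter joins the wrapped lines without merging the literals themselves. A small self-contained illustration (not code from the repository):

# Adjacent string literals are concatenated at compile time, so the
# joined diff line builds the same message as one single literal.
joined = "{} is slow for array inputs. " "np.concatenate() is faster."
single = "{} is slow for array inputs. np.concatenate() is faster."
assert joined == single
print(joined.format("np.array"))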
1 change: 1 addition & 0 deletions autograd/numpy/random.py
@@ -1,4 +1,5 @@
 import numpy.random as npr
+
 from .numpy_wrapper import wrap_namespace
 
 wrap_namespace(npr.__dict__, globals())
5 changes: 1 addition & 4 deletions autograd/scipy/__init__.py
@@ -1,7 +1,4 @@
-from . import integrate
-from . import signal
-from . import special
-from . import stats
+from . import integrate, signal, special, stats
 
 try:
     from . import misc
4 changes: 2 additions & 2 deletions autograd/scipy/integrate.py
@@ -1,10 +1,10 @@
 import scipy.integrate
 
 import autograd.numpy as np
-from autograd.extend import primitive, defvjp_argnums
 from autograd import make_vjp
-from autograd.misc import flatten
 from autograd.builtins import tuple
+from autograd.extend import defvjp_argnums, primitive
+from autograd.misc import flatten
 
 odeint = primitive(scipy.integrate.odeint)
 
3 changes: 2 additions & 1 deletion autograd/scipy/linalg.py
@@ -1,9 +1,10 @@
 from functools import partial
 
 import scipy.linalg
+
 import autograd.numpy as anp
+from autograd.extend import defjvp, defjvp_argnums, defvjp, defvjp_argnums
 from autograd.numpy.numpy_wrapper import wrap_namespace
-from autograd.extend import defvjp, defvjp_argnums, defjvp, defjvp_argnums
 
 wrap_namespace(scipy.linalg.__dict__, globals())  # populates module namespace
 
1 change: 1 addition & 0 deletions autograd/scipy/misc.py
@@ -1,4 +1,5 @@
 import scipy.misc as osp_misc
+
 from ..scipy import special
 
 if hasattr(osp_misc, "logsumexp"):
7 changes: 4 additions & 3 deletions autograd/scipy/signal.py
@@ -1,10 +1,11 @@
 from functools import partial
-import autograd.numpy as np
-import numpy as npo  # original numpy
-from autograd.extend import primitive, defvjp
 
+import numpy as npo  # original numpy
 from numpy.lib.stride_tricks import as_strided
+
+import autograd.numpy as np
+from autograd.extend import defvjp, primitive
 
 
 @primitive
 def convolve(A, B, axes=None, dot_axes=[(), ()], mode="full"):
5 changes: 3 additions & 2 deletions autograd/scipy/special.py
@@ -1,7 +1,8 @@
 import scipy.special
+
 import autograd.numpy as np
-from autograd.extend import primitive, defvjp, defjvp
-from autograd.numpy.numpy_vjps import unbroadcast_f, repeat_to_match_shape
+from autograd.extend import defjvp, defvjp, primitive
+from autograd.numpy.numpy_vjps import repeat_to_match_shape, unbroadcast_f
 
 ### Beta function ###
 beta = primitive(scipy.special.beta)
7 changes: 1 addition & 6 deletions autograd/scipy/stats/__init__.py
@@ -1,9 +1,4 @@
-from . import chi2
-from . import beta
-from . import gamma
-from . import norm
-from . import poisson
-from . import t
+from . import beta, chi2, gamma, norm, poisson, t
 
 # Try block needed in case the user has an
 # old version of scipy without multivariate normal.
5 changes: 3 additions & 2 deletions autograd/scipy/stats/beta.py
@@ -1,6 +1,7 @@
-import autograd.numpy as np
 import scipy.stats
-from autograd.extend import primitive, defvjp
+
+import autograd.numpy as np
+from autograd.extend import defvjp, primitive
 from autograd.numpy.numpy_vjps import unbroadcast_f
 from autograd.scipy.special import beta, psi
 