Skip to content

Commit

Permalink
indexing types
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 17, 2024
1 parent f6af157 commit 9bae35c
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions numpyro/ops/indexing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from typing import Any

from jax import Array
import jax.numpy as jnp
from jax.typing import ArrayLike


def _is_batched(arg):
return jnp.ndim(arg) > 0


def vindex(tensor, args):
def vindex(tensor: ArrayLike, args: tuple[Any, ...]) -> Array:
"""
Vectorized advanced indexing with broadcasting semantics.
Expand Down Expand Up @@ -72,10 +76,10 @@ def vindex(tensor, args):
This implementation is similar to the proposed notation
``x.vindex[]`` except for slightly different handling of ``Ellipsis``.
:param jnp.ndarray tensor: A tensor to be indexed.
:param tuple args: An index, as args to ``__getitem__``.
:param ArrayLike tensor: A tensor to be indexed.
:param tuple[Any, ...] args: An index, as args to ``__getitem__``.
:returns: A nonstandard interpretation of ``tensor[args]``.
:rtype: jnp.ndarray
:rtype: Array
"""
if not isinstance(args, tuple):
return tensor[args]
Expand Down Expand Up @@ -140,8 +144,8 @@ class Vindex:
:return: An object with a special :meth:`__getitem__` method.
"""

def __init__(self, tensor):
def __init__(self, tensor: ArrayLike) -> None:
self._tensor = tensor

def __getitem__(self, args):
def __getitem__(self, args: tuple[Any, ...]) -> Array:
return vindex(self._tensor, args)

0 comments on commit 9bae35c

Please sign in to comment.