Skip to content

Commit

Permalink
Support Python 3.9 (#1923)
Browse files Browse the repository at this point in the history
* python 3.9 lint

* support python 3.9 lint

* extend python versions

* extend tests matrix: modeling and inference

* use 3.12 for a moment
  • Loading branch information
juanitorduz authored Dec 2, 2024
1 parent 07d5f2e commit a4fd59b
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 18 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9","3.10","3.12"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -50,7 +50,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand Down Expand Up @@ -83,7 +83,7 @@ jobs:
needs: lint
strategy:
matrix:
python-version: ["3.10"]
python-version: ["3.9", "3.10"]

steps:
- uses: actions/checkout@v2
Expand Down
4 changes: 2 additions & 2 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from collections import OrderedDict
from functools import partial
from typing import Callable
from typing import Callable, Optional

import jax
from jax import device_put, lax, random
Expand Down Expand Up @@ -348,7 +348,7 @@ def scan(
f: Callable,
init,
xs,
length: int | None = None,
length: Optional[int] = None,
reverse: bool = False,
history: int = 1,
):
Expand Down
12 changes: 6 additions & 6 deletions numpyro/contrib/stochastic_support/dcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod
from collections import OrderedDict, namedtuple
from typing import Any, Callable, OrderedDict as OrderedDictType
from typing import Any, Callable, OrderedDict as OrderedDictType, Union

import jax
from jax import random
Expand All @@ -21,9 +21,9 @@

SDVIResult = namedtuple("SDVIResult", ["guides", "slp_weights"])

RunInferenceResult = (
dict[str, Any] | tuple[AutoGuide, dict[str, Any]]
) # for mcmc or sdvi
RunInferenceResult = Union[
dict[str, Any], tuple[AutoGuide, dict[str, Any]]
] # for mcmc or sdvi


class StochasticSupportInference(ABC):
Expand Down Expand Up @@ -124,12 +124,12 @@ def _combine_inferences(
branching_traces: dict[str, OrderedDictType],
*args: Any,
**kwargs: Any,
) -> DCCResult | SDVIResult:
) -> Union[DCCResult, SDVIResult]:
raise NotImplementedError

def run(
self, rng_key: ArrayLike, *args: Any, **kwargs: Any
) -> DCCResult | SDVIResult:
) -> Union[DCCResult, SDVIResult]:
"""
Run inference on each SLP separately and combine the results.
Expand Down
5 changes: 3 additions & 2 deletions numpyro/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from collections import OrderedDict
from itertools import product
from typing import Union

import numpy as np

Expand Down Expand Up @@ -230,7 +231,7 @@ def hpdi(x: np.ndarray, prob: float = 0.90, axis: int = 0) -> np.ndarray:


def summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
) -> dict:
"""
Returns a summary table displaying diagnostics of ``samples`` from the
Expand Down Expand Up @@ -284,7 +285,7 @@ def summary(


def print_summary(
samples: dict | np.ndarray, prob: float = 0.90, group_by_chain: bool = True
samples: Union[dict, np.ndarray], prob: float = 0.90, group_by_chain: bool = True
) -> None:
"""
Prints a summary table displaying diagnostics of ``samples`` from the
Expand Down
10 changes: 5 additions & 5 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import random
import re
from threading import Lock
from typing import Any, Callable, Generator
from typing import Any, Callable, Generator, Optional
import warnings

import numpy as np
Expand All @@ -27,7 +27,7 @@
_CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3'


def set_rng_seed(rng_seed: int | None = None) -> None:
def set_rng_seed(rng_seed: Optional[int] = None) -> None:
"""
Initializes internal state for the Python and NumPy random number generators.
Expand All @@ -49,7 +49,7 @@ def enable_x64(use_x64: bool = True) -> None:
jax.config.update("jax_enable_x64", use_x64)


def set_platform(platform: str | None = None) -> None:
def set_platform(platform: Optional[str] = None) -> None:
"""
Changes platform to CPU, GPU, or TPU. This utility only takes
effect at the beginning of your program.
Expand Down Expand Up @@ -408,7 +408,7 @@ def loop_fn(collection):


def soft_vmap(
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: int | None = None
fn: Callable, xs: Any, batch_ndims: int = 1, chunk_size: Optional[int] = None
) -> Any:
"""
Vectorizing map that maps a function `fn` over `batch_ndims` leading axes
Expand Down Expand Up @@ -466,7 +466,7 @@ def format_shapes(
*,
compute_log_prob: bool = False,
title: str = "Trace Shapes:",
last_site: str | None = None,
last_site: Optional[str] = None,
):
"""
Given the trace of a function, returns a string showing a table of the shapes of
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,7 @@
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
],
)

0 comments on commit a4fd59b

Please sign in to comment.