Skip to content

Commit

Permalink
Address DeprecationWarning: jax.random.KeyArray is deprecated. Use ja…
Browse files Browse the repository at this point in the history
…x.Array for annotations
  • Loading branch information
bwohlberg committed Oct 5, 2023
1 parent 216ffc8 commit 6873fd6
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion scico/flax/train/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from .typed_dict import DataSetDict

DType = Any
KeyArray = Union[Array, jax.random.PRNGKeyArray]
KeyArray = Union[Array, jax.Array]


class IterateData:
Expand Down
2 changes: 1 addition & 1 deletion scico/flax/train/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .typed_dict import ConfigDict, ModelVarDict

ModuleDef = Any
KeyArray = Union[Array, jax.random.PRNGKeyArray]
KeyArray = Union[Array, jax.Array]
PyTree = Any
ArrayTree = optax.Params

Expand Down
2 changes: 1 addition & 1 deletion scico/flax/train/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .state import TrainState
from .typed_dict import DataSetDict, MetricsDict

KeyArray = Union[Array, jax.random.PRNGKeyArray]
KeyArray = Union[Array, jax.Array]
PyTree = Any


Expand Down
3 changes: 2 additions & 1 deletion scico/flax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@
from .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict

ModuleDef = Any
KeyArray = Union[Array, jax.random.PRNGKeyArray]
KeyArray = Union[Array, jax.Array]
PyTree = Any
DType = Any


# sync across replicas
def sync_batch_stats(state: TrainState) -> TrainState:
"""Sync the batch statistics across replicas."""
Expand Down

0 comments on commit 6873fd6

Please sign in to comment.