Skip to content

Commit

Permalink
Suppress mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Nov 6, 2024
1 parent dd88cb4 commit 4feba96
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scico/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,11 @@ def _trace_arg_repr(val: Any) -> str:
if isinstance(val, jax.Array) and not isinstance(
val, jax._src.interpreters.partial_eval.JaxprTracer
):
if call_trace.show_jax_device:
if call_trace.show_jax_device: # type: ignore
platform = list(val.devices())[0].platform # assume all of same type
devices = ",".join(map(str, sorted([d.id for d in val.devices()])))
dev_str = f"{clr_devc}{{dev={platform}({devices})}}{clr_args}"
if call_trace.show_jax_sharding and isinstance(
if call_trace.show_jax_sharding and isinstance( # type: ignore
val.sharding, jax._src.sharding_impls.PositionalSharding
):
shard_str = f"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}"
Expand Down

0 comments on commit 4feba96

Please sign in to comment.