From 4feba96f5a054248bc0a9e5882887db259151057 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Tue, 5 Nov 2024 18:02:21 -0700 Subject: [PATCH] Suppress mypy errors --- scico/trace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scico/trace.py b/scico/trace.py index d58924ee..2c38f69f 100644 --- a/scico/trace.py +++ b/scico/trace.py @@ -116,11 +116,11 @@ def _trace_arg_repr(val: Any) -> str: if isinstance(val, jax.Array) and not isinstance( val, jax._src.interpreters.partial_eval.JaxprTracer ): - if call_trace.show_jax_device: + if call_trace.show_jax_device: # type: ignore platform = list(val.devices())[0].platform # assume all of same type devices = ",".join(map(str, sorted([d.id for d in val.devices()]))) dev_str = f"{clr_devc}{{dev={platform}({devices})}}{clr_args}" - if call_trace.show_jax_sharding and isinstance( + if call_trace.show_jax_sharding and isinstance( # type: ignore val.sharding, jax._src.sharding_impls.PositionalSharding ): shard_str = f"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}"