Skip to content

Commit

Permalink
Silence some pytype errors.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610876480
  • Loading branch information
rchen152 authored and KfacJaxDev committed Feb 27, 2024
1 parent 5c64ae7 commit 8ed9d27
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion kfac_jax/_src/curvature_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,7 @@ def array_to_parameters_shaped_list(self, array: Array) -> Tuple[Array, ...]:
@property
def array_shape(self) -> Shape:
"""The shape of the single non axis grouped array."""
avals = [jnp.zeros(v.aval.shape) for v in self.parameter_variables]
avals = [jnp.zeros(v.aval.shape) for v in self.parameter_variables] # pytype: disable=attribute-error # always-use-property-annotation
return self.parameters_shaped_list_to_array(avals).shape

@property
Expand Down
2 changes: 1 addition & 1 deletion kfac_jax/_src/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def __eq__(self, other: ProcJaxpr) -> bool:
return False

# Verify whether parameter shapes are equivalent
if any(p_i.aval.shape != p_j.aval.shape
if any(p_i.aval.shape != p_j.aval.shape # pytype: disable=attribute-error # always-use-property-annotation
for p_i, p_j in zip(self.params_vars_flat, other.params_vars_flat)):
return False

Expand Down

0 comments on commit 8ed9d27

Please sign in to comment.