Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
evanatyourservice committed Dec 15, 2024
1 parent f7f2382 commit 07781e6
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/levanter/optim/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def broadcast_qs(_, ps, q, s):

if return_partition_specs_only:
return dict(
key=jax.random.PRNGKey(jax.process_index()),
key=PartitionSpec(),
count=PartitionSpec(),
mu=mu_sharding,
Qs_preconditioners=Qs_sharding,
Expand All @@ -516,6 +516,7 @@ def broadcast_qs(_, ps, q, s):
)

return dict(
key=jax.random.PRNGKey(jax.process_index()),
count=jnp.zeros([], jnp.int32),
mu=mu,
Qs_preconditioners=Qs,
Expand Down

0 comments on commit 07781e6

Please sign in to comment.