Hi team,

First off, I absolutely love JAX. It's the core engine behind our startup.
It would be fantastic to have a rank-one update to an eigenvalue decomposition of a symmetric PSD matrix $A$. That is, when $A = V \Lambda V^\top$, we can compute $\tilde{V}\tilde{\Lambda}\tilde{V}^\top = \tilde{A} = A + \rho xx^\top$ in $\mathcal{O}(n^2)$, rather than the $\mathcal{O}(n^3)$ it takes to run the eigenvalue decomposition again on $\tilde{A}$.
In principle this is very similar to a Cholesky rank-one update, which is already implemented in jax._src.lax.linalg.cholesky_update, albeit without a batching rule.
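For context, the reduction behind the $\mathcal{O}(n^2)$ claim is standard linear algebra, not tied to any particular implementation:

$$A + \rho xx^\top = V\left(\Lambda + \rho\,(V^\top x)(V^\top x)^\top\right)V^\top,$$

so the problem reduces to the eigendecomposition of a diagonal-plus-rank-one matrix. Its eigenvalues are the roots of the secular equation

$$f(\mu) = 1 + \rho \sum_{i=1}^{n} \frac{z_i^2}{\lambda_i - \mu} = 0, \qquad z = V^\top x,$$

and each updated eigenvector is proportional to $(\Lambda - \tilde{\lambda}_i I)^{-1} z$, which is exactly what the code below computes.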
Motivating case
The rank-one update to eigenvalue decompositions comes up in several rolling matrix algorithms. My own personal use case is covariance matrix shrinkage in an exponentially weighted setting, i.e. a recursion of the form $\Sigma_t = \lambda\,\Sigma_{t-1} + (1-\lambda)\,x_t x_t^\top$ (or similar), where I need to regularize the eigenvalues of each $\Sigma_t$. I do this for batched time series where the matrices themselves are of moderate size (30x30 to 300x300). With a rank-one update I could carry the eigenvalue decomposition through a scan instead of the covariance matrix itself, as sketched below.
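To make the pattern concrete, here is a rough sketch of what I would like to write (illustrative only; `eigh_rank_one_update` is the function defined further below, and the shrinkage step is just a placeholder):

```python
import jax
import jax.numpy as jnp


def ewma_eigh_scan(w0, v0, xs, lam=0.97):
    """Carry the (eigenvalues, eigenvectors) of an EWMA covariance through a scan.

    Each step rescales the current eigenvalues by lam (which leaves the
    eigenvectors unchanged) and then applies a rank-one update with the new
    observation x_t, so no O(n^3) eigh is needed inside the loop.
    """

    def step(carry, x):
        w, v = carry
        # Sigma_t = lam * Sigma_{t-1} + (1 - lam) * x x^T
        w, v = eigh_rank_one_update(lam * w, v, x, 1.0 - lam)
        # ... eigenvalue regularization / shrinkage would go here ...
        return (w, v), (w, v)

    _, (ws, vs) = jax.lax.scan(step, (w0, v0), xs)
    return ws, vs
```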
Required algorithms
The underlying algorithms this requires are available in scipy.linalg.cython_lapack as dlasd4 and slasd4 for CPU (which I think jaxlib references under the hood), and in MAGMA on GPU (https://github.com/kjbartel/magma/blob/master/src/dlaex3.cpp). The GPU implementation seems to just parallelize over LAPACK and BLAS calls.
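For reference, a pure-Python version of what my `*lasd4_all` helpers below do could look roughly like this sketch, using SciPy's `scipy.linalg.lapack.dlasd4` wrapper (my actual implementation is a Cython loop; I believe the index argument keeps LAPACK's 1-based convention, so treat the details as assumptions):

```python
import numpy as np
from scipy.linalg import lapack


def dlasd4_all_reference(d, z, rho):
    """Return all n updated singular values of diag(d**2) + rho * z z^T.

    LAPACK's dlasd4 expects d sorted ascending with distinct entries and
    rho > 0; it returns the square root of the i-th updated eigenvalue.
    """
    n = d.shape[0]
    out = np.empty(n, dtype=np.float64)
    for i in range(n):
        # Index assumed to be 1-based, as in LAPACK.
        delta, sigma, work, info = lapack.dlasd4(i + 1, d, z, rho=rho)
        if info != 0:
            raise RuntimeError(f"dlasd4 failed at index {i} (info={info})")
        out[i] = sigma
    return out
```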
Current implementation
My current CPU implementation is not batched or parallelized (it literally loops over every single eigenvalue), and is therefore rather slow (but still a lot faster than just running eigh on the recursion output). It is based on some simple linear algebra (see https://math.stackexchange.com/questions/3052997/eigenvalues-of-a-rank-one-update-of-a-matrix), with the LAPACK calls wrapped in a pure_callback:
```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike

from my_package.lasd4_all import dlasd4_all, slasd4_all  # type: ignore


def _lasd4_all_host(d: Array, z: Array, rho: Array) -> Array:
    _d = np.asarray(d)
    _z = np.asarray(z)
    if d.dtype == np.float32:
        return slasd4_all(_d, _z, rho)
    elif d.dtype == np.float64:
        return dlasd4_all(_d, _z, rho)
    else:
        raise ValueError(f"Unsupported dtype {d.dtype}")


def _lasd4_all(d: Array, z: Array, rho: Array) -> Array:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.lapack.slasd4.html
    dtype = jnp.promote_types(d.dtype, z.dtype)
    d = jnp.asarray(d, dtype=dtype)
    z = jnp.asarray(z, dtype=dtype)
    rho = jnp.asarray(rho, dtype=dtype)
    result_shape = jax.ShapeDtypeStruct(d.shape, dtype)
    return jax.pure_callback(
        _lasd4_all_host, result_shape, d, z, rho, vmap_method="sequential"
    )


def _eigh_rank_one_update(
    w: Array, v: Array, z: Array, rho: Array
) -> tuple[Array, Array]:
    # https://en.wikipedia.org/wiki/Bunch%E2%80%93Nielsen%E2%80%93Sorensen_formula
    # using VBV' = diag(w) + rho*(vz)(vz)'
    w_out = jnp.square(_lasd4_all(jnp.sqrt(w), z @ v, rho))
    Q = (z @ v) / (w - w_out[..., None])
    Q = Q / jnp.linalg.norm(Q, axis=1, keepdims=True)
    v_out = v @ Q.T
    return w_out, v_out


@partial(jnp.vectorize, signature="(n),(n,n),(n),()->(n),(n,n)")
def eigh_rank_one_update(
    w: ArrayLike, v: ArrayLike, z: ArrayLike, rho: ArrayLike
) -> tuple[Array, Array]:
    """Compute the eigenvalues and eigenvectors of a rank-one update of a symmetric matrix.

    Let A be a symmetric matrix with eigenvalues w and eigenvectors v, i.e.
    A = v @ jnp.diag(w) @ v.T. Let z be a vector and rho a scalar. This function
    computes the eigenvalues and eigenvectors of B = A + rho * z @ z.T.

    Parameters
    ----------
    w : ArrayLike
        The eigenvalues of A, of shape ``(..., n)``.
    v : ArrayLike
        The eigenvectors of A, of shape ``(..., n, n)``.
    z : ArrayLike
        The vector z, of shape ``(..., n)``.
    rho : ArrayLike
        The scalar rho, of shape ``(...)``.

    Returns
    -------
    tuple[Array, Array]
        The eigenvalues of shape ``(..., n)`` and eigenvectors of shape
        ``(..., n, n)``, of A + rho * z @ z.T.
    """
    w, v, z, rho = jnp.asarray(w), jnp.asarray(v), jnp.asarray(z), jnp.asarray(rho)
    norm = jnp.linalg.norm(z, axis=-1, keepdims=True)
    w_, v_ = _eigh_rank_one_update(w, v, z / norm, rho * jnp.square(norm.squeeze(-1)))
    skip = norm * rho < 1e-15
    w = jnp.where(skip, w, w_)
    v = jnp.where(skip[..., None], v, v_)
    return w, v
```
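And here is the kind of sanity check I run against `jnp.linalg.eigh` (illustrative; it assumes the snippet above is importable):

```python
import jax.numpy as jnp
import numpy as np

n = 50
rng = np.random.default_rng(0)
A = rng.standard_normal((n, n))
A = A @ A.T  # symmetric PSD
x = rng.standard_normal(n)
rho = 0.5

# Eigendecomposition of A, then the rank-one update of the factors.
w, v = jnp.linalg.eigh(A)
w_upd, v_upd = eigh_rank_one_update(w, v, x, rho)

# Reconstruct A + rho * x x^T from the updated factors and compare.
B = v_upd @ jnp.diag(w_upd) @ v_upd.T
print(jnp.max(jnp.abs(B - (A + rho * np.outer(x, x)))))
```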
I would also be super happy with any guidance on how to make my current implementation go faster. As a non-computer scientist, I was not able to get a firm enough grip on C++/CUDA/XLA/JAX internals to implement this in a smarter, more parallel way myself, but if there is any way I can help with this feature, I would love to try.
One suggestion that might improve performance would be to push the batching logic into Cython. In other words, perhaps you could update slasd4_all (etc.) to support 2-D inputs and then have a nested for-loop in the body. Then you could use vmap_method='broadcast_all' instead of vmap_method='sequential'. This might help because the sequential method results in a separate Python call (with the associated overhead) for each element in the batch, instead of just a single "vectorized" Python call.
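A rough sketch of that shape, assuming a hypothetical `dlasd4_all_batched` Cython routine that accepts 2-D `d`/`z` and 1-D `rho` and loops over the rows internally:

```python
import jax
import numpy as np


def _lasd4_all_host(d, z, rho):
    # With vmap_method="broadcast_all" the callback can receive either the
    # unbatched shapes ((n,), (n,), ()) or shapes with leading batch dimensions
    # broadcast onto every argument, so flatten the batch dims before handing
    # off to the (hypothetical) Cython routine.
    d, z, rho = np.asarray(d), np.asarray(z), np.asarray(rho)
    batch_shape = d.shape[:-1]
    d2 = d.reshape(-1, d.shape[-1])
    z2 = z.reshape(-1, z.shape[-1])
    rho2 = np.broadcast_to(rho, batch_shape).reshape(-1)
    out = dlasd4_all_batched(d2, z2, rho2)  # one Python call for the whole batch
    return out.reshape(d.shape)


def _lasd4_all(d, z, rho):
    result_shape = jax.ShapeDtypeStruct(d.shape, d.dtype)
    return jax.pure_callback(
        _lasd4_all_host, result_shape, d, z, rho,
        vmap_method="broadcast_all",
    )
```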
In the longer term, it's a reasonable feature request for JAX to support this operation out of the box, but it would need a strong argument for the team to implement it at high priority. So it would be useful to make sure that your current implementation works well enough for now!