
Rank-one updates to eigenvalue decompositions #25057

Open
mishavanbeek opened this issue Nov 22, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@mishavanbeek

Hi team,

First off, I absolutely love JAX. It's the core engine behind our startup.

It would be fantastic to have a rank-one update to an eigenvalue decomposition of a symmetric PSD matrix $A$. That is, when $A=LDL^\top$, we can compute $\tilde{L}\tilde{D}\tilde{L}^\top=\tilde{A}=A+\rho xx^\top$ in $\mathcal{O}(n^2)$ (rather than the $\mathcal{O}(n^3)$ of running the eigenvalue decomposition again on $\tilde{A}$).

In principle this is very similar to a Cholesky rank-one update, which is implemented in jax._src.lax.linalg.cholesky_update, albeit without a batching rule.

Motivating case
The rank-one update to eigenvalue decompositions comes up in several rolling matrix algorithms. My own use case is covariance matrix shrinkage in an exponentially weighted setting:

$$\Sigma_{t+1}=\lambda\Sigma_t + (1-\lambda)x_tx_t^\top,$$

where I need to regularize the eigenvalues of each $\Sigma_t$. I do this for batched time series where the matrices themselves are of moderate size (roughly 30x30 to 300x300). With a rank-one update I could run a scan on the eigenvalue decomposition rather than on the covariance matrix itself.
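
To make that concrete, the recursion I have in mind is roughly the following sketch (it uses the eigh_rank_one_update helper defined further down, and glosses over the actual shrinkage step):

import jax


def rolling_shrunk_eigh(w0, v0, xs, lam):
    # Carry the eigendecomposition of Sigma_t through the scan instead of
    # Sigma_t itself: Sigma_{t+1} = lam * Sigma_t + (1 - lam) * x_t x_t^T.
    def step(carry, x):
        w, v = carry
        # Scaling Sigma_t by lam scales its eigenvalues by lam; the data term
        # is then a rank-one update with rho = 1 - lam.
        w, v = eigh_rank_one_update(lam * w, v, x, 1.0 - lam)
        # ... regularize / shrink the eigenvalues w here ...
        return (w, v), (w, v)

    _, (ws, vs) = jax.lax.scan(step, (w0, v0), xs)
    return ws, vs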

Required algorithms
The underlying algorithms this requires are available in scipy.linalg.cython_lapack as dlasd4 and slasd4 for CPU (which I think jaxlib references under the hood), and in MAGMA on GPU (https://github.com/kjbartel/magma/blob/master/src/dlaex3.cpp). The GPU implementation seems to just parallelize over LAPACK and BLAS calls.
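
For intuition, what ?lasd4 computes for each index i is a root of the secular equation of a diagonal-plus-rank-one matrix. A slow pure-SciPy reference (useful only for checking results; it assumes sorted distinct d, no zero entries in z and rho > 0, and skips the deflation logic LAPACK performs) could look like:

import numpy as np
from scipy.optimize import brentq


def rank_one_update_eigenvalues(d, z, rho):
    # Eigenvalues of diag(d) + rho * outer(z, z) are the roots of the secular
    # equation f(t) = 1 + rho * sum_j z_j**2 / (d_j - t). For rho > 0 they
    # interlace with d: one root in each (d[i], d[i+1]) and one above d[-1].
    d = np.asarray(d, dtype=float)  # assumed sorted ascending, all distinct
    z = np.asarray(z, dtype=float)  # assumed to have no (near-)zero entries
    f = lambda t: 1.0 + rho * np.sum(z**2 / (d - t))
    hi = d[-1] + rho * np.dot(z, z) + 1.0  # strictly above the largest root
    brackets = list(zip(d[:-1], d[1:])) + [(d[-1], hi)]
    eps = 1e-9 * (d[-1] - d[0] + 1.0)  # keep the poles out of the brackets
    return np.array([brentq(f, a + eps, b - eps) for a, b in brackets])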

Current implementation
My current CPU implementation is not batched or parallelized (it literally loops over every single eigenvalue), and is therefore rather slow (but still a lot faster than just running eigh on the recursion output). It is based on some simple linear algebra (see https://math.stackexchange.com/questions/3052997/eigenvalues-of-a-rank-one-update-of-a-matrix):

# lasd4_all.pyx -- single precision only, for brevity

cimport cython
cimport numpy as cnp
from scipy.linalg.cython_lapack cimport dlasd4, slasd4

import numpy as np

DTYPE_SINGLE = np.float32
cnp.import_array()
ctypedef cnp.float32_t DTYPE_SINGLE_t


@cython.boundscheck(False)
@cython.wraparound(False)
def slasd4_all(
    cnp.ndarray[DTYPE_SINGLE_t, ndim=1] d,
    cnp.ndarray[DTYPE_SINGLE_t, ndim=1] z,
    float rho
):
    """Return the n updated singular values sigma, where sigma**2 are the
    eigenvalues of diag(d**2) + rho * outer(z, z) (see LAPACK ?lasd4)."""
    cdef:
        int n = d.shape[0]
        int i
        float sigma
        int info
        cnp.ndarray[DTYPE_SINGLE_t, ndim=1] work = np.zeros(n, dtype=DTYPE_SINGLE)
        cnp.ndarray[DTYPE_SINGLE_t, ndim=1] delta = np.zeros(n, dtype=DTYPE_SINGLE)
        cnp.ndarray[DTYPE_SINGLE_t, ndim=1] sigmas = np.zeros(n, dtype=DTYPE_SINGLE)

    # LAPACK's slasd4 uses 1-based indexing; each call solves the secular
    # equation for the i-th updated singular value.
    for i in range(1, n + 1):
        slasd4(&n, &i, <float*>d.data, <float*>z.data, <float*>delta.data, &rho,
            &sigmas[i - 1], <float*>work.data, &info)

    return sigmas
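
The extension builds as a standard Cython module; for completeness, a minimal setup.py sketch (paths are illustrative, and NumPy, Cython and SciPy need to be available at build time):

# setup.py -- minimal build sketch for the lasd4_all extension
import numpy as np
from Cython.Build import cythonize
from setuptools import Extension, setup

setup(
    ext_modules=cythonize(
        [
            Extension(
                "my_package.lasd4_all",
                ["my_package/lasd4_all.pyx"],
                include_dirs=[np.get_include()],  # needed for cimport numpy
            )
        ],
        compiler_directives={"language_level": "3"},
    )
)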

and then wrapping it in a callback:

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import Array
from jax.typing import ArrayLike

from my_package.lasd4_all import dlasd4_all, slasd4_all  # type: ignore


def _lasd4_all_host(d: Array, z: Array, rho: Array) -> Array:
    _d = np.asarray(d)
    _z = np.asarray(z)
    if d.dtype == np.float32:
        return slasd4_all(_d, _z, rho)
    elif d.dtype == np.float64:
        return dlasd4_all(_d, _z, rho)
    else:
        raise ValueError(f"Unsupported dtype {d.dtype}")


def _lasd4_all(d: Array, z: Array, rho: Array) -> Array:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.lapack.slasd4.html
    dtype = jnp.promote_types(d.dtype, z.dtype)
    d = jnp.asarray(d, dtype=dtype)
    z = jnp.asarray(z, dtype=dtype)
    rho = jnp.asarray(rho, dtype=dtype)
    result_shape = jax.ShapeDtypeStruct(d.shape, dtype)
    return jax.pure_callback(
        _lasd4_all_host, result_shape, d, z, rho, vmap_method="sequential"
    )


def _eigh_rank_one_update(
    w: Array, v: Array, z: Array, rho: Array
) -> tuple[Array, Array]:
    # https://en.wikipedia.org/wiki/Bunch%E2%80%93Nielsen%E2%80%93Sorensen_formula
    # using VBV' = diag(w) + rho*(vz)(vz)'
    w_out = jnp.square(_lasd4_all(jnp.sqrt(w), z @ v, rho))
    Q = (z @ v) / (w - w_out[..., None])
    Q = Q / jnp.linalg.norm(Q, axis=1, keepdims=True)
    v_out = v @ Q.T
    return w_out, v_out


@partial(jnp.vectorize, signature="(n),(n,n),(n),()->(n),(n,n)")
def eigh_rank_one_update(
    w: ArrayLike, v: ArrayLike, z: ArrayLike, rho: ArrayLike
) -> tuple[Array, Array]:
    """Compute the eigenvalues and eigenvectors of a rank-one update of a symmetric
    matrix.

    Let A be a symmetric matrix with eigenvalues w and eigenvectors v, i.e.
    A = v @ jnp.diag(w) @ v.T. Let z be a vector and rho a scalar. This function
    computes the eigenvalues and eigenvectors of B = A + rho * z @ z.T.

    Parameters
    ----------
    w : ArrayLike
        The eigenvalues of A, of shape ``(..., n)``.
    v : ArrayLike
        The eigenvectors of A, of shape ``(..., n, n)``.
    z : ArrayLike
        The vector z, of shape ``(..., n)``.
    rho : ArrayLike
        The scalar rho, of shape ``(...)``.

    Returns
    -------
    tuple[Array, Array]
        The eigenvalues of shape ``(..., n)`` and eigenvectors of shape ``(..., n, n)``,
        of A + rho * z @ z.T.
    """
    w, v, z, rho = jnp.asarray(w), jnp.asarray(v), jnp.asarray(z), jnp.asarray(rho)
    norm = jnp.linalg.norm(z, axis=-1, keepdims=True)
    w_, v_ = _eigh_rank_one_update(w, v, z / norm, rho * jnp.square(norm.squeeze(-1)))
    skip = norm * rho < 1e-15
    w = jnp.where(skip, w, w_)
    v = jnp.where(skip[..., None], v, v_)
    return w, v
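
For reference, a quick sanity check of the wrapper against an explicit eigh (this assumes x64 is enabled; the residuals grow when the updated eigenvalues are closely spaced, since nothing here does deflation):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

key_b, key_z = jax.random.split(jax.random.key(0))
n = 50
B = jax.random.normal(key_b, (n, n), dtype=jnp.float64)
A = B @ B.T  # symmetric PSD test matrix
z = jax.random.normal(key_z, (n,), dtype=jnp.float64)
rho = 0.3

w, v = jnp.linalg.eigh(A)
w_new, v_new = eigh_rank_one_update(w, v, z, rho)

A_new = A + rho * jnp.outer(z, z)
# Reconstruction residual against the explicitly updated matrix ...
print(jnp.max(jnp.abs(v_new @ jnp.diag(w_new) @ v_new.T - A_new)))
# ... and eigenvalue residual against a fresh eigh of A_new. Both should be small.
print(jnp.max(jnp.abs(jnp.sort(w_new) - jnp.linalg.eigvalsh(A_new))))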

I would also be super happy with any guidance on how to make my current implementation go faster. As a non-computer scientist, I was not able to get a firm enough grip on the C++/CUDA/XLA/JAX internals to implement this in a smarter, more parallel way myself, but if there is any way I can help with this feature, I would love to try.

@mishavanbeek mishavanbeek added the enhancement New feature or request label Nov 22, 2024
@dfm
Collaborator

dfm commented Nov 22, 2024

The approach you're using here looks good to me!

One suggestion that might improve performance would be to push the batching logic into Cython. In other words, perhaps you could update slasd4_all (etc.) to support 2-D inputs and then have a nested for-loop in the body. Then you could use vmap_method='broadcast_all' instead of vmap_method='sequential'. This might help because the sequential method results in a separate Python call (with the associated overhead) for each element in the batch, instead of just a single "vectorized" Python call.
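
Roughly what I mean, sketched on the Python side (slasd4_all_batched here is a hypothetical 2-D version of your Cython function that loops over the rows of d and z):

import jax
import jax.numpy as jnp
import numpy as np


def _lasd4_all_host_batched(d, z, rho):
    # d, z: (..., n); rho: (...,). One Python call handles the whole batch;
    # the inner loops live in Cython (hypothetical slasd4_all_batched).
    d, z, rho = np.asarray(d), np.asarray(z), np.asarray(rho)
    batch_shape, n = d.shape[:-1], d.shape[-1]
    out = slasd4_all_batched(d.reshape(-1, n), z.reshape(-1, n), rho.reshape(-1))
    return out.reshape(*batch_shape, n)


def _lasd4_all(d, z, rho):
    dtype = jnp.promote_types(d.dtype, z.dtype)
    d = jnp.asarray(d, dtype=dtype)
    z = jnp.asarray(z, dtype=dtype)
    rho = jnp.asarray(rho, dtype=dtype)
    result_shape = jax.ShapeDtypeStruct(d.shape, dtype)
    # With broadcast_all, vmap hands the callback fully batched arrays in a
    # single call, instead of one Python call per batch element.
    return jax.pure_callback(
        _lasd4_all_host_batched, result_shape, d, z, rho,
        vmap_method="broadcast_all",
    )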

In the longer term, it's a reasonable feature request for JAX to support this operation out of the box, but it would need a strong argument for the team to implement it at high priority. So it would be useful to make sure that your current implementation works well enough for now!

@mishavanbeek
Author

Thanks Dan! I appreciate the tip, and I understand that this is a pretty niche matrix operation.
