Incorrect use of __init__ inside the unflatten rule of CoLA custom pytrees causes issues [JAX] #20

Closed · adam-hartshorne opened this issue Aug 20, 2023 · 7 comments · Fixed by #29

adam-hartshorne commented Aug 20, 2023

When attempting to incorporate a CoLA Linear Operator as a field in Equinox (https://github.com/patrick-kidger/equinox), as shown in the MVE below, I receive AttributeError: 'bool' object has no attribute 'dtype'.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import jax.random as jr
import optax
import equinox as eqx
import cola


class MyModule(eqx.Module):
    lazy_A: cola.ops.LinearOperator

    def __init__(self, A):
        self.lazy_A = cola.fns.lazify(A)

    def __call__(self, x):
        return self.lazy_A @ x


seed = jr.PRNGKey(0)
A = jr.normal(seed, (10, 10))
X = jnp.ones((10, 1))
model = MyModule(A)
result = eqx.filter(model, eqx.is_inexact_array)
```

  File "/media/adam/shared_drive/PycharmProjects/test_equinox_lazy_variable/test_equinox_lazy_variable.py", line 24, in <module>
    result = eqx.filter(model, eqx.is_inexact_array)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 129, in filter
    filter_tree = jtu.tree_map(_make_filter_tree(is_leaf), filter_spec, pytree)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_filters.py", line 71, in _filter_tree
    return jtu.tree_map(mask, arg, is_leaf=is_leaf)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/jax/_src/tree_util.py", line 252, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 245, in tree_unflatten
    return cls(*new_args, **aux[0])
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 21, in __init__
    super().__init__(dtype=A.dtype, shape=A.shape)
AttributeError: 'bool' object has no attribute 'dtype'

As per patrick-kidger/equinox#453 (comment), the issue appears to be a symptom of a wider bug in CoLA's tree-unflattening code.

The issue is that they're calling __init__ inside the unflatten rule of their pytrees:

```python
return cls(*new_args, **aux[0])
```

This is a common mistake when implementing custom pytrees in JAX; see the documentation here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization

They could fix this by either (a) having their LinearOperator inherit from eqx.Module, or (b) unflattening via __new__ and __setattr__; see the Equinox implementation here for reference.
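As a minimal sketch of option (b) for a hypothetical operator class (not CoLA's actual code): unflattening goes through __new__ and direct attribute assignment, so __init__ never runs and leaf values are never inspected, even when tree_map substitutes placeholders such as booleans for the leaves.

```python
import jax.tree_util as jtu

@jtu.register_pytree_node_class
class MyOperator:
    def __init__(self, A):
        # __init__ inspects A, so it is unsafe to call while unflattening,
        # when leaves may be placeholders (e.g. bools) rather than arrays.
        self.A = A
        self.dtype = A.dtype
        self.shape = A.shape

    def tree_flatten(self):
        # children: traceable array data; aux: hashable static metadata
        return (self.A,), (self.dtype, self.shape)

    @classmethod
    def tree_unflatten(cls, aux, children):
        obj = object.__new__(cls)  # bypass __init__ entirely
        (obj.A,) = children
        obj.dtype, obj.shape = aux
        return obj
```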

adam-hartshorne added the bug label Aug 20, 2023
adam-hartshorne changed the title to add the "[JAX]" tag Aug 20, 2023
adam-hartshorne (Author) commented

I am finding that this also causes multiple issues with things like using cola.ops.partial.

mfinzi (Collaborator) commented Aug 23, 2023

Ah, thanks for bringing this to our attention @adam-hartshorne.
We may not be able to inherit directly from eqx.Module since we also need to support PyTorch with no JAX installation. But we should be able to sort this out with __new__; I will have a look with @AndPotap. It seems we need some additional tests here (tree flattening and unflattening had to be written separately for the custom vjps so as to support PyTorch).

We haven't had an official release of the library yet but will be doing so shortly (the first PyPI release later this week, and the full release with the accompanying paper likely the following week).

mfinzi self-assigned this Aug 23, 2023
adam-hartshorne (Author) commented Aug 23, 2023

I agree that directly inheriting probably isn't the best solution, as it wouldn't keep CoLA as general-purpose as possible.

However, I believe that, done correctly, this should still leave CoLA compatible with Equinox (and other libraries providing similar functionality), as ultimately all of these revolve around pytree structures; e.g. GPJax uses its own custom pytree class and remains compatible with Equinox.

patrick-kidger commented Aug 26, 2023

Hi -- creator of Equinox here.

FWIW, one possibility might be something like:

```python
try:
    import jax
except ImportError:
    class Module: pass
else:
    from equinox import Module

class LinearOperator(Module): ...
```

But alternatively you absolutely could register your own pytrees too. I'd recommend taking a look at the implementation of equinox.Module to see how it's done / how to handle various edge cases.


Side note, whilst I'm here -- I'd be curious to hear more about comparisons between CoLA and Lineax. I wasn't aware of CoLA whilst writing Lineax.

mfinzi (Collaborator) commented Aug 26, 2023

Hi @patrick-kidger

Glad to see you here! I did experiment with something like you suggest, with equinox (conditionally) as the base class, but I ran into several issues since the fields of LinearOperators aren't declared explicitly (and we would like to avoid enforcing this for the creation of new operators if possible). Thinking about it more, we want to support pytrees in PyTorch too so as to have feature parity, so we may as well use the same approach for both.

Instead, I took some inspiration from your comments in jax-ml/jax#16170 and overrode __setattr__ to mark fields as either static or dynamic during __init__. It's not the cleanest solution, but I could not find another way to make it work without significantly intruding on the public LinearOperator interface. I have not yet merged this implementation (PR #29) into main (a few details are pending), but it works with vmap/jit in the simple cases I tested, and I believe the approach should work more generally. Perhaps you have thoughts on edge cases, though, as I see you have handled numerous tricky ones when registering Modules in a general way in Equinox.
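For reference, here is a rough sketch of that idea under assumed names (a hypothetical illustration, not the actual PR #29 code): __setattr__ records each field's kind at assignment time, and flatten/unflatten then split the fields accordingly.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu

@jtu.register_pytree_node_class
class Operator:
    def __setattr__(self, name, value):
        # Record at assignment time whether the field is dynamic
        # (array data to be traced) or static (plain metadata).
        kinds = self.__dict__.setdefault("_kinds", {})
        kinds[name] = "dynamic" if isinstance(value, jax.Array) else "static"
        object.__setattr__(self, name, value)

    def tree_flatten(self):
        names = tuple(k for k, kind in self._kinds.items() if kind == "dynamic")
        children = tuple(getattr(self, k) for k in names)
        static = tuple((k, v) for k, v in vars(self).items()
                       if k != "_kinds" and self._kinds.get(k) != "dynamic")
        return children, (names, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        names, static = aux
        obj = object.__new__(cls)  # again, no __init__ call here
        object.__setattr__(obj, "_kinds", {n: "dynamic" for n in names})
        # static fields bypass __setattr__, so they default to static
        # (absent from _kinds) on the next flatten.
        for n, v in list(zip(names, children)) + list(static):
            object.__setattr__(obj, n, v)
        return obj

@jtu.register_pytree_node_class  # each subclass needs registering too;
class Diagonal(Operator):        # CoLA automates this with a metaclass
    def __init__(self, d):
        self.d = d                             # dynamic: a jax.Array
        self.shape = (d.shape[0], d.shape[0])  # static metadata

D = Diagonal(jnp.arange(1.0, 4.0))
print(jtu.tree_leaves(D))  # [Array([1., 2., 3.], dtype=float32)]
```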

Thoughts on CoLA vs Lineax:
With CoLA, one of the things that we wanted to bring to the table first and foremost is specialization rules that can be used to accelerate a given algorithm for a given structure, for example $(A\otimes B)^{-1} = A^{-1} \otimes B^{-1}$, or $\mathrm{tr}(\mathrm{Perm}[\sigma]) = \sum_i \mathbb{1}_{[\sigma_i=i]}$. Rather than using objects and inheritance (which limits how these rules can be extended inside and outside the library), we add these rules with multiple dispatch, so they can be extended even from outside the library. This is a bit like your AutoLinearSolver, but with a more general and extensible mechanism, which allows us to use it all over the library.
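As a toy sketch of that dispatch style (hypothetical classes, assuming the plum-dispatch package; this is not CoLA's actual operator API):

```python
from dataclasses import dataclass
import numpy as np
from plum import dispatch

@dataclass
class Dense:
    M: np.ndarray

@dataclass
class Kronecker:
    A: object
    B: object

@dispatch
def inv(A: Dense):
    return Dense(np.linalg.inv(A.M))

@dispatch
def inv(K: Kronecker):
    # Specialization rule: (A ⊗ B)^{-1} = A^{-1} ⊗ B^{-1}, so the full
    # Kronecker product never needs to be formed.
    return Kronecker(inv(K.A), inv(K.B))
```

Because dispatch is open, a downstream user can register an inv rule for a new operator type without touching the library itself.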


This also allows certain approaches to be implemented more cleanly. For example, calling A = cola.cholesky_decomposed(A) means that A will be a Product[Triangular, Triangular], so calling logdet(A: Product) will split the logdet over the product and then use the logdet(A: Triangular) rule, which needs only the diagonal of the triangular matrix.
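Continuing the toy sketch from above (again hypothetical, not CoLA's code), that chain of rules might look like:

```python
from dataclasses import dataclass
import numpy as np
from plum import dispatch

@dataclass
class Triangular:
    M: np.ndarray

@dataclass
class Product:
    factors: tuple

@dispatch
def logdet(A: Triangular):
    # Only the diagonal of a triangular factor is needed.
    return float(np.sum(np.log(np.abs(np.diag(A.M)))))

@dispatch
def logdet(A: Product):
    # The logdet splits over the factors of a product.
    return sum(logdet(F) for F in A.factors)

L = np.tril(np.random.rand(4, 4)) + 4 * np.eye(4)
A = Product((Triangular(L), Triangular(L.T)))  # as if Cholesky-decomposed
print(logdet(A))
```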

Another key difference, though less relevant for Equinox and Google, is native PyTorch support. While I like using JAX, we have a lot of use cases and potential needs that call for PyTorch support.

In terms of specific linear algebra operations, as you mention we don't yet have support for pseudoinverses (hopefully soon to come) or BiCGStab, but we have more support for other kinds of operations (trace and diagonal estimation, exp(A)v, logdet). I wouldn't say these are fundamental differences, though, just minor differences in what's implemented in each library right now.

CoLA is very new; while we have been working on it for some time, we have only just had our first PyPI release, and our public announcement of the library is still to come (in the next week or two). So it's not surprising that you hadn't seen CoLA, but hopefully we can learn from any insights or important design decisions you made while building Lineax, if you're willing to share :)

mfinzi linked pull request #29 on Aug 26, 2023 that will close this issue
patrick-kidger commented

Hey there!

For the __setattr__ approach: I think this should work, but I've not tried it in practice. Try it and let us know?

I think the most important edge case to handle is ensuring that bound methods are still pytrees.
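To illustrate that edge case (my illustration, using Equinox's behaviour as the reference): a bound method closes over its module, so the method object itself must flatten to the module's arrays for it to pass through filters and transformations.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import equinox as eqx

class Scale(eqx.Module):
    w: jax.Array

    def apply(self, x):
        return self.w * x

m = Scale(jnp.ones(3))
g = m.apply  # a bound method, closing over m and hence over the array w

# With eqx.Module, the bound method is itself a pytree whose leaves are
# m's arrays, so it can be filtered/transformed like any other pytree.
print(jtu.tree_leaves(g))  # [Array([1., 1., 1.], dtype=float32)]
```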

CoLA/Lineax: FWIW I've not yet found a case where I needed multiple dispatch. (Just as well, as this isn't very well-supported in the Python ecosystem; e.g. I believe plum/beartype doesn't always handle parametric types correctly.) That said, you may find e.g. lineax.{linearise,materialise,...} interesting, as these are module-level functions using single dispatch. I had similar concerns to you, in that I didn't want to insist on a particular set of operations ahead of time by making them methods.
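For contrast, a generic illustration of that single-dispatch pattern (hypothetical operator class; not Lineax's actual implementation):

```python
from dataclasses import dataclass
from functools import singledispatch
import numpy as np

@dataclass
class Diagonal:
    d: np.ndarray

# A module-level function rather than a method; new operator types can
# register their own implementation from outside the library.
@singledispatch
def materialise(op):
    raise NotImplementedError(f"materialise: unsupported type {type(op)}")

@materialise.register
def _(op: Diagonal):
    return np.diag(op.d)

print(materialise(Diagonal(np.arange(3.0))))
```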

I can see that the desire to support PyTorch is an extra challenge for you. I think that's going to be very tricky, as there are a fair number of idiomatic differences between the libraries. For example, lineax.linear_solve actually calls out to a custom JAX primitive in order to provide the appropriate transpose rules, or the appropriate JVP rule when the solver supports dependent rows/columns. For this kind of reason, I've mostly given up on cross-framework libraries.

AndPotap added a commit that referenced this issue Aug 28, 2023
Address pytree compatibility in both JAX and PyTorch, to allow
vmap and jit with LinearOperator arguments or outputs. See e.g. #20 or #26.
This will also enable using CoLA operators within
[equinox](https://github.com/patrick-kidger/equinox) modules.

Replaces the custom flattening function (which is slow and not very general)
with optree (for the pytorch and numpy backends) and jax.tree_util for jax.
Modifies __setattr__ of LinearOperator to record which of vars(obj) are
dynamic (pytrees or arrays) and which are static (everything else), so that
flatten and unflatten can separate the two. This annotation of static vs
dynamic fields needs to be done during __init__, as discussed in the jax issue
[https://github.com/google/jax/issues/16170](https://github.com/google/jax/issues/16170)
(even though doing so inside flatten is ostensibly compatible with the
pytree specification).
The LinearOperator metaclass is set to one which automatically registers
each linear operator as a pytorch pytree (if installed), an optree pytree,
and a jax pytree (if installed). Optree is used for numpy linear operators
and will also eventually replace the pytorch pytrees, as per the
intentions of the pytorch devs.


With this functionality it should also be possible to construct batched
linear operators and use them in batched matrix routines.
E.g. 
```python
import jax
import cola
from jax import vmap

# construct a batch of 10 Diagonal operators from random diagonals
diags = jax.random.normal(jax.random.PRNGKey(0), (10, 1000))
Ds = vmap(cola.ops.Diagonal)(diags)
logdets = vmap(cola.linalg.logdet)(Ds)
```

It may also be possible to simplify our custom autograd rules once
LinearOperators are pytrees, though I have not attempted this yet.
mfinzi (Collaborator) commented Aug 28, 2023

We've now added proper support for pytrees in JAX and PyTorch. @adam-hartshorne, the Equinox code above now works (on main).
