Incorrect use of init inside the unflatten rule of CoLA custom pytrees causes issues [JAX] #20
Comments
I am finding this is also causing multiple issues when it comes to things like using `cola.ops.partial`.
Ah, thanks for bringing this to our attention @adam-hartshorne. We haven't had an official release of the library but will be doing so shortly (with the first PyPI release later this week, and the full release with accompanying paper likely the following week).
I agree that directly inheriting probably isn't the best solution, as it wouldn't keep CoLA as general purpose as possible. However, I believe that, done correctly, CoLA should still remain compatible with Equinox (or several other libraries that provide similar functionality), as ultimately all of these revolve around pytree structures; e.g. GPJax uses its own custom PyTree class and remains compatible with Equinox.
Hi -- creator of Equinox here. FWIW, one possibility might be something like:

```python
try:
    import jax
except ImportError:
    class Module:  # dummy base class when JAX (and hence Equinox) isn't available
        pass
else:
    from equinox import Module

class LinearOperator(Module): ...
```

But alternatively you absolutely could register your own pytrees too. I'd recommend taking a look at the implementation in Equinox.

Side note, whilst I'm here -- I'd be curious to hear more about comparisons between CoLA and Lineax. I wasn't aware of CoLA whilst writing Lineax.
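For concreteness, here is a minimal sketch of what registering your own pytree looks like in JAX (a hypothetical toy `Diag` class, not CoLA's or Equinox's actual code):

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Diag:
    """Toy diagonal operator registered directly as a JAX pytree."""

    def __init__(self, diag):
        self.diag = diag  # dynamic data: an array leaf

    def tree_flatten(self):
        # children (transformed/traced by JAX) and static auxiliary data
        return (self.diag,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


A = Diag(jnp.arange(3.0))
# Because Diag is a registered pytree, it can be passed through jit/vmap.
y = jax.jit(lambda op, x: op.diag * x)(A, jnp.ones(3))
```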
Glad to see you here! I did experiment with doing something like you suggest, with equinox (conditionally) as the base class, but I ran into several issues related to the fields of `LinearOperator`. Instead I took some inspiration from your comments in jax-ml/jax#16170 and overwrote `__setattr__` to record which fields are dynamic and which are static.

Thoughts on CoLA vs Lineax: one difference is our use of multiple dispatch for the linear algebra routines. This also allows certain approaches to be more cleanly implemented. For example, calling a routine on a structured operator can dispatch to an algorithm specialized for that structure.

Another key difference I would say, while not as relevant for Equinox and Google, is native PyTorch support. While I like using JAX, we have a lot of use cases and potential need for PyTorch support.

In terms of specific linear algebra operations, as you mention we don't yet have support for pseudoinverses (hopefully soon to come) or BiCGStab, but we also have more support for other kinds of linear algebra operations (trace and diagonal estimation, exp(A)v, logdet), though I wouldn't say these are fundamental differences, rather just minor differences in what's implemented in the libraries right now.

CoLA's very new; while we have been working on it for some time, we have only just had our first PyPI release, and our public announcement of the library is still to come (in the next week or two). So it's not surprising that you didn't see CoLA, but hopefully we can learn from any insights or important design decisions you may have had while building Lineax, if you're willing to share :)
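A minimal sketch of the `__setattr__`-based bookkeeping mentioned above (hypothetical names, not CoLA's actual implementation), assuming a JAX-only backend: the dynamic/static split is recorded at assignment time, so unflattening can rebuild the object without calling `__init__`:

```python
import jax
import jax.numpy as jnp


class PyTreeOperator:
    """Records, as attributes are assigned, which fields hold array data."""

    def __setattr__(self, name, value):
        leaves = jax.tree_util.tree_leaves(value)
        if any(isinstance(leaf, jax.Array) for leaf in leaves):
            vars(self).setdefault("_dynamic", []).append(name)
        object.__setattr__(self, name, value)

    def tree_flatten(self):
        dyn = tuple(vars(self).get("_dynamic", []))
        children = tuple(getattr(self, k) for k in dyn)
        static = tuple((k, v) for k, v in vars(self).items()
                       if k not in dyn and k != "_dynamic")
        return children, (dyn, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        dyn, static = aux
        # Rebuild without calling __init__ (the point discussed in jax#16170).
        obj = object.__new__(cls)
        vars(obj).update(dict(static))
        vars(obj)["_dynamic"] = list(dyn)
        vars(obj).update(dict(zip(dyn, children)))
        return obj


@jax.tree_util.register_pytree_node_class
class ScaledIdentity(PyTreeOperator):
    def __init__(self, alpha, n):
        self.alpha = alpha  # array -> recorded as dynamic
        self.n = n          # int   -> kept static

    def mv(self, x):
        return self.alpha * x


A = ScaledIdentity(jnp.array(2.0), n=3)
y = jax.jit(lambda op, x: op.mv(x))(A, jnp.ones(3))
```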
Hey there! On the pytree side, I think the most important edge case to handle is ensuring that bound methods are still pytrees.

CoLA/Lineax: FWIW I've not yet found a case where I needed multiple dispatch. (Just as well, as this isn't very well-supported in the Python ecosystem. E.g. I believe plum/beartype doesn't always handle parametric types correctly.)

I can see that the desire to support PyTorch is an extra challenge for you. I think that's going to be very tricky, as there's a fair number of idiomatic differences between the libraries.
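To illustrate the bound-method edge case, here is a hedged sketch (reusing a toy `Diag` operator, not any real CoLA or Equinox class): a plain bound method is not a valid `jit` argument, but `jax.tree_util.Partial` can pair the unbound function with the operator so the operator's arrays remain visible to JAX:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Diag:
    def __init__(self, diag):
        self.diag = diag

    def mv(self, x):
        return self.diag * x

    def tree_flatten(self):
        return (self.diag,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)


def apply(f, x):
    return f(x)


A = Diag(jnp.arange(1.0, 4.0))
# Passing A.mv directly to jax.jit(apply) would fail: a bound method is not
# a pytree. Partial stores the function statically and A as a pytree child.
mv = jax.tree_util.Partial(Diag.mv, A)
y = jax.jit(apply)(mv, jnp.ones(3))
```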
Address compatibility for pytrees in both JAX and PyTorch, to allow `vmap` and `jit` with LinearOperator arguments or outputs. See e.g. #20 or #26. Will also enable using CoLA operators within [equinox](https://github.com/patrick-kidger/equinox) modules.

- Replaces the custom flattening function (which is slow and not very general) with optree (for the pytorch and numpy backends) and `jax.tree_util` for jax.
- Modifies `__setattr__` of LinearOperator to record which of `vars(obj)` are dynamic (pytrees or arrays) or static (other). Flatten and unflatten can then separate the two. This annotation of which fields are static and which are dynamic needs to be done during `__init__`, as discussed in this jax issue [https://github.com/google/jax/issues/16170](https://github.com/google/jax/issues/16170) (even though doing so inside flatten is ostensibly compatible with the pytree specification).
- The LinearOperator metaclass is set to one which automatically registers each linear operator as a pytorch pytree (if installed), an optree pytree, and a jax pytree (if installed). Optree is used for numpy linear operators and will also eventually replace the pytorch pytrees, as per the intentions of the pytorch devs.

With this functionality it should also be possible to construct batched linear operators and use them in batched matrix routines. E.g.

```python
import cola
from jax import vmap
from numpy.random import randn

Ds = vmap(cola.ops.Diagonal)(randn(10, 1000))
logdets = vmap(cola.linalg.logdet)(Ds)
```

It may also be possible to simplify our custom autograd rules once LinearOperators are pytrees, though I have not attempted this yet.
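For the optree side, a rough sketch (hypothetical `Dense` class; exact signatures and namespacing should be checked against the optree version in use) of registering an operator class for the numpy backend:

```python
import numpy as np
import optree


class Dense:
    """Toy dense operator holding a single array field."""

    def __init__(self, A):
        self.A = A


optree.register_pytree_node(
    Dense,
    lambda op: ((op.A,), None),                   # flatten: (children, metadata)
    lambda metadata, children: Dense(*children),  # unflatten
    namespace="cola",                             # keep registrations scoped
)

op = Dense(np.eye(3))
leaves, treedef = optree.tree_flatten(op, namespace="cola")
op_roundtrip = optree.tree_unflatten(treedef, leaves)
```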
We've just added proper support for pytrees in JAX and PyTorch now. @adam-hartshorne, the above equinox code now works (on main).
When attempting to incorporate a CoLA LinearOperator as a field of an Equinox (https://github.com/patrick-kidger/equinox) module, as shown in the MVE below, I receive `AttributeError: 'bool' object has no attribute 'dtype'`.
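The original MVE from the report is not available here; below is a hedged sketch of the kind of usage described (assuming recent `cola` and `equinox` APIs, with a `Diagonal` operator standing in for whatever operator the real MVE used):

```python
import cola
import equinox as eqx
import jax.numpy as jnp


class Model(eqx.Module):
    op: cola.ops.LinearOperator

    def __call__(self, x):
        return self.op @ x


model = Model(cola.ops.Diagonal(jnp.arange(1.0, 4.0)))
# On affected CoLA versions, flattening/unflattening the module here is what
# surfaced the AttributeError described above.
y = eqx.filter_jit(model)(jnp.ones(3))
```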
As per patrick-kidger/equinox#453 (comment), the issue appears to be a symptom of a wider bug in CoLA tree unflattening code.