jax.vmap fails for custom PyTree #16170
Comments
Hello @mattjj and @hawkinsp (tagging you because you are mentioned in the code). I took a look at the way If I understand correctly, the core ( I came up with a simpler implementation, which does not rely on generators and does not require replacing the leaves of the trees to flatten the axes.
import jax
import jax.tree_util as jtu

def flatten_axes(tree, axes):  # /!\ uses the tree instead of the treedef
    flat_axes = []

    def add_axes(axis, x):
        flat_axes.extend([axis] * len(jtu.tree_leaves(x)))

    jtu.tree_map(add_axes, axes, tree)
    return flat_axes

def vmap(fun, in_axes, out_axes, axis_size, axis_name, spmd_axis_name):
    def wrapped(*ins):
        flat_in_axes = flatten_axes(ins, in_axes)
        flat_ins, treedef_ins = jtu.tree_flatten(ins)

        if axis_size is None:
            axis_size = mystery_logic(flat_ins, flat_in_axes)

        flat_out_axes = []
        store = []

        def flat_fun(flat_ins):
            ins = jtu.tree_unflatten(treedef_ins, flat_ins)
            outs = fun(*ins)
            flat_out_axes.extend(flatten_axes(outs, out_axes))
            flat_outs, treedef_outs = jtu.tree_flatten(outs)
            store.append(treedef_outs)
            return flat_outs

        flat_outs = flat_vmap(
            flat_fun,
            flat_in_axes,
            flat_out_axes,  # filled when flat_fun is called
            axis_size,
            axis_name,
            spmd_axis_name,
        )(flat_ins)

        treedef_outs = store.pop()
        return jtu.tree_unflatten(treedef_outs, flat_outs)

    return wrapped
I don't really understand all the technicalities in |
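To make the proposed `flatten_axes` helper concrete, here is a minimal runnable sketch that expands an `in_axes` prefix into one axis per leaf. The example arguments (`args`, `in_axes`) are made up for illustration, and the sketch does not handle `None` axes, which JAX treats as empty subtrees rather than leaves.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def flatten_axes(tree, axes):
    """Expand a prefix tree of axes into one axis per leaf of `tree`."""
    flat_axes = []

    def add_axes(axis, subtree):
        # Every leaf below this prefix position inherits the same axis.
        flat_axes.extend([axis] * len(jtu.tree_leaves(subtree)))

    # `axes` is the prefix; `tree` is flattened only up to its structure.
    jtu.tree_map(add_axes, axes, tree)
    return flat_axes


# Illustrative inputs (assumed): one array and one dict of two arrays.
args = (jnp.zeros((3, 2)), {"w": jnp.ones((2, 3)), "b": jnp.ones((2, 3))})
in_axes = (0, 1)

flat_in_axes = flatten_axes(args, in_axes)
print(flat_in_axes)  # [0, 1, 1]
```

Because the leaf counts are read from the actual arguments, no placeholder leaves are needed in this version.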
The issue here is that your pytree flattening is conditioned on whether the attributes are array instances – this breaks
def tree_flatten(self):
    return [self.x], [None]
|
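A minimal sketch of why type-conditional flattening is fragile, assuming (as the issue description states) that leaves get swapped for `object()` placeholders before the tree is flattened again. `TypeDependent`, `Unconditional`, and `leaf_counts` are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class TypeDependent:
    """Hypothetical class: x is a child only while it is an array."""

    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        if isinstance(self.x, jax.Array):
            return (self.x,), None
        return (), self.x

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0] if children else aux)


@jtu.register_pytree_node_class
class Unconditional:
    """Hypothetical class: x is always a child, whatever its type."""

    def __init__(self, x):
        self.x = x

    def tree_flatten(self):
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])


def leaf_counts(tree):
    """Leaf count before and after swapping leaves for object() placeholders."""
    leaves, treedef = jtu.tree_flatten(tree)
    dummy = jtu.tree_unflatten(treedef, [object()] * len(leaves))
    return len(leaves), len(jtu.tree_leaves(dummy))


print(leaf_counts(TypeDependent(jnp.ones(3))))  # (1, 0)
print(leaf_counts(Unconditional(jnp.ones(3))))  # (1, 1)
```

The type-dependent class loses its leaf once the array is replaced by a placeholder, while the unconditional one keeps the same leaf count.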
Hello @jakevdp, thank you for your comment, but I actually need that logic to make smart pytrees that are easy to use with I know that The new implementation of |
OK, thanks for the context. I'm assigning @mattjj who could take a look at the alternate vmap implementation proposal. |
@francois-rozet -- you may like Equinox, which does many of the things you're discussing here. |
Hello @patrick-kidger, I know equinox, I am actually trying to improve it 😅 (I want to make it more compatible with the default jax transformations, without having to specify static fields). I'll get back to you when I have a working proof of concept! |
Ah, I've actually figured out the answer to that one! Don't use dataclasses (like Equinox does). Then override Either way, then inspect the value specifically at init time. If it's an array, mark it dynamic. If it's anything else, mark it static. This preserves pytree semantics as the inspection is only done at init time, and not later during flattening. (After which e.g. leaves may have been substituted out with I've been contemplating changing Equinox over to the above approach. It's not clear to me what the consequences to backward compatibility are, though. |
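A minimal sketch of the init-time inspection idea, under the assumption that fields are passed as keyword arguments. `InitTagged` and its attribute names are hypothetical and not taken from Equinox or JAX.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class InitTagged:
    """Hypothetical container: fields are tagged dynamic/static once, in __init__."""

    def __init__(self, **fields):
        self.fields = dict(fields)
        # Decided once here and never re-inspected during flattening.
        self.dynamic_names = tuple(
            name for name, value in fields.items() if isinstance(value, jax.Array)
        )

    def tree_flatten(self):
        children = tuple(self.fields[name] for name in self.dynamic_names)
        static = tuple(
            (name, value)
            for name, value in self.fields.items()
            if name not in self.dynamic_names
        )
        return children, (self.dynamic_names, static)

    @classmethod
    def tree_unflatten(cls, aux, children):
        dynamic_names, static = aux
        obj = object.__new__(cls)  # bypass __init__, so nothing is re-tagged
        obj.fields = dict(static)
        obj.fields.update(zip(dynamic_names, children))
        obj.dynamic_names = dynamic_names
        return obj


model = InitTagged(weight=jnp.ones((4, 3)), name="linear")

# Placeholder leaves no longer change the structure.
leaves, treedef = jtu.tree_flatten(model)
dummy = jtu.tree_unflatten(treedef, [object()] * len(leaves))
print(len(leaves), len(jtu.tree_leaves(dummy)))  # 1 1

# The tagged pytree can be passed through jax.vmap.
row_sums = jax.vmap(lambda m: m.fields["weight"].sum())(model)
print(row_sums.shape)  # (4,)
```

Since the tags travel in the auxiliary data, unflattening restores them verbatim regardless of what the leaves have become.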
I think this approach would break if you want a "list of arrays", a "dict of modules", a list of Partial objects with modules as functions, or user-defined pytrees ... inside a Module. It is impossible to enumerate all cases because there are infinitely many. And modifying I have a proof of concept in this repo: https://github.com/francois-rozet/inox. The name is a tribute to equinox, and means stainless in French. The core are the There is a
|
I think this is still doable, by flattening every field into a list of leaves, and checking each one. All non-arrays can be wrapped into a This would also be quite nice in that it means we might be able to avoid the confusion over when to use static fields, by simply not having them be part of the public API any more. WDYT?
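A minimal sketch of the per-leaf wrapping idea, assuming a wrapper whose value lives entirely in auxiliary data. The name `Static` and the helpers `wrap_static` / `unwrap_static` are hypothetical and chosen only for this illustration.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Static:
    """Hypothetical wrapper: a pytree node with no children."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # value is carried as auxiliary data

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(value)


def wrap_static(field):
    """Wrap every non-array leaf of `field`, once, at construction time."""
    return jtu.tree_map(
        lambda leaf: leaf if isinstance(leaf, jax.Array) else Static(leaf), field
    )


def unwrap_static(field):
    """Undo wrap_static, treating Static nodes as leaves."""
    return jtu.tree_map(
        lambda node: node.value if isinstance(node, Static) else node,
        field,
        is_leaf=lambda node: isinstance(node, Static),
    )


field = {"weights": [jnp.ones(2), jnp.zeros(3)], "activation": "relu"}
wrapped = wrap_static(field)

leaves, treedef = jtu.tree_flatten(wrapped)
dummy = jtu.tree_unflatten(treedef, [object()] * len(leaves))
print(len(leaves), len(jtu.tree_leaves(dummy)))  # 2 2
print(unwrap_static(wrapped)["activation"])  # relu
```

This handles nested containers such as a list of arrays, because the check runs on each leaf rather than on each field.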
The problem is that this just isn't compatible with JAX's model of what a pytree is. (I.e. that the structure does not depend on the types of the leaves.) And indeed Equinox makes heavy use of this invariant too: for example
Haha, thank you! That's great to see.
I like it! I'm now wondering why I ever called it I would caution against your approach to statefulness, using in-place updates. This gets pretty hairy around JAX transforms: if I do
my_model = ...  # uses batch norm internally
leaves, treedef = jtu.tree_flatten(my_model)
my_model2 = jtu.tree_unflatten(treedef, leaves)
my_model2(...)
then the original |
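One way to read this caution, sketched under the assumption that the concern is about unflattening producing a new object: an in-place update applied to the rebuilt pytree does not reach the original. `Counter` is a hypothetical stand-in for a stateful module, not code from inox or Equinox.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Counter:
    """Hypothetical stateful module that mutates itself when called."""

    def __init__(self, count):
        self.count = count

    def tree_flatten(self):
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0])

    def __call__(self):
        self.count = self.count + 1  # in-place update of the instance
        return self.count


original = Counter(jnp.asarray(0))
leaves, treedef = jtu.tree_flatten(original)
rebuilt = jtu.tree_unflatten(treedef, leaves)  # a new Counter instance
rebuilt()

print(int(rebuilt.count))   # 1
print(int(original.count))  # 0
```

JAX transforms flatten and unflatten their arguments internally, so the same separation between the caller's object and the one seen inside the transformed function can be expected there.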
@patrick-kidger Let's take this discussion elsewhere (francois-rozet/inox#1), as it is not directly related to the issue I submitted. |
Address compatibility for pytrees both in JAX and PyTorch, to allow vmap and jit with LinearOperator arguments or outputs. See e.g. #20 or #26. Will also enable using CoLA operators within [equinox](https://github.com/patrick-kidger/equinox) modules.
Replaces the custom flattening function (which is slow and not very general) with optree (for the pytorch and numpy backends) and jax.tree_util for jax.
Modifies __setattr__ of LinearOperator to record which of vars(obj) are dynamic (pytrees or arrays) or static (other). Then flatten and unflatten can separate the two. This annotation of which are static and which dynamic needs to be done during the init, as discussed in this jax issue [https://github.com/google/jax/issues/16170](https://github.com/google/jax/issues/16170) (even though doing so inside flatten is ostensibly compatible with the pytree specification).
The LinearOperator metaclass is set to one which automatically registers each linear operator as a pytorch pytree (if installed), optree pytree, and jax pytree (if installed). Optree is used for numpy linear operators and will also eventually replace the pytorch pytrees as per the intentions of the pytorch devs.
With this functionality it should also be possible to construct batched linear operators and use them in batched matrix routines. E.g.
```python
Ds = vmap(cola.ops.Diagonal)(randn(10,1000))
logdets = vmap(cola.linalg.logdet)(Ds)
```
It may also be possible to simplify our custom autograd rules once LinearOperators are pytrees, though I have not attempted this yet.
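A rough sketch of how a registering metaclass and a tagging `__setattr__` could fit together for the JAX case only. `AutoPytree`, `Operator`, and `Diagonal` are hypothetical stand-ins and do not reproduce CoLA's actual implementation.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


class AutoPytree(type):
    """Hypothetical metaclass: registers every class it creates as a JAX pytree."""

    def __init__(cls, name, bases, namespace):
        super().__init__(name, bases, namespace)
        jtu.register_pytree_node(cls, cls._flatten, cls._unflatten)


class Operator(metaclass=AutoPytree):
    """Hypothetical base class standing in for LinearOperator."""

    def __setattr__(self, name, value):
        # Tag the attribute when it is assigned, not when it is flattened.
        leaves = jtu.tree_leaves(value)
        is_dynamic = bool(leaves) and all(isinstance(leaf, jax.Array) for leaf in leaves)
        self.__dict__.setdefault("_tags", {})[name] = is_dynamic
        object.__setattr__(self, name, value)

    def _flatten(self):
        tags = self.__dict__["_tags"]
        dynamic_names = tuple(n for n in tags if tags[n])
        children = tuple(self.__dict__[n] for n in dynamic_names)
        static = tuple((n, self.__dict__[n]) for n in tags if not tags[n])
        return children, (dynamic_names, static)

    @classmethod
    def _unflatten(cls, aux, children):
        dynamic_names, static = aux
        obj = object.__new__(cls)
        # Write through __dict__ so the stored tags are restored, not recomputed.
        obj.__dict__.update(static)
        obj.__dict__.update(zip(dynamic_names, children))
        tags = {n: False for n, _ in static}
        tags.update({n: True for n in dynamic_names})
        obj.__dict__["_tags"] = tags
        return obj


class Diagonal(Operator):
    def __init__(self, diag):
        self.diag = diag
        self.kind = "diagonal"


# Build a batch of operators by vmapping over the constructor.
Ds = jax.vmap(Diagonal)(jnp.ones((10, 5)))
print(Ds.diag.shape, Ds.kind)  # (10, 5) diagonal
```

Subclasses need no registration code of their own here, since the metaclass runs for each new class.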
Hello @jakevdp, @mattjj 👋 I have found a workaround for this issue which makes custom PyTrees that auto-detect non-array leaves compatible with the current
import jax
@jax.tree_util.register_pytree_node_class
class Custom:
    def __init__(self, key, switch=True):
        if switch:
            self.x = jax.random.normal(key)  # not-static
        else:
            self.x = 'static'

    def tree_flatten(self):
        if isinstance(self.x, jax.Array) or type(self.x) is object:  # also check whether x is an object()
            return [self.x], [None]
        else:
            return [None], [self.x]

    @classmethod
    def tree_unflatten(cls, static, children):
        y, x = static[0], children[0]
        x = x if y is None else y
        self = object.__new__(cls)
        self.x = x
        return self

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 4)
jax.vmap(Custom)(keys)
However, I still think that the sketch implementation I propose at #16170 (comment) is better than the current one because it is (a) simpler, as it does not use generators, (b) faster, as it flattens/unflattens fewer trees, and (c) allows for arbitrary custom PyTrees, as it does not assume the tree structure is independent from the leaf types. Note that, in fact, JAX already does not assume that the tree structure is independent from the leaf type as replacing a leaf by |
Won't the pytree registry help here? |
I am not sure what you mean. |
@francois-rozet -- I would really strongly discourage this leaf-type-detection approach. Since JAX pytrees can in principle have leaves of any type, then third-party code can and does make use of this. Your work may be compatible with core JAX (or at least any errors are passing silently, rather than failing loudly), but it might break when used with other libraries. As an example of another library running afoul of this, see google-deepmind/distrax#193. Again there is leaf type detection during flattening, and this then breaks things later. I think if you're interested in doing auto-[dynamic/static] leaf type detection, then I'd propose doing it during |
Description
I implemented a custom PyTree class that automatically separates array and non-array leaves. jax.vmap fails when the input/output is an instance of that class.
fails with
I believe the error comes from the way jax.vmap tries to assign vectorized axis to the tree leaves.
https://github.com/google/jax/blob/ae9160a4e9f14992c9f53a38a1aeaf146eacbf16/jax/_src/api_util.py#L395-L400
Here, it is assumed that unflattening the treedef with object() leaves instead of the original leaves, then flattening the obtained tree will lead to the same (number of) leaves. This is not the case with my custom PyTree. I see two ways of fixing this:
object() which can be used in custom PyTrees to ensure that leaves stay the same.
jax.vmap to not rely on placeholder leaves.
What jax/jaxlib version are you using?
jax 0.4.10, jaxlib 0.4.10
Which accelerator(s) are you using?
CPU
Additional system info
Python 3.9.16, Ubuntu 22.04
NVIDIA GPU info
No response