API example for @compact is type unstable #973

Open
jesseylin opened this issue Oct 7, 2024 · 2 comments

@jesseylin

The following example from the API documentation (copy-pasted below for reference) uses
w2 = [Dense(128, 128) for i in 1:nlayers]
instead of
w2 = Chain([Dense(128, 128) for i in 1:nlayers])
and produces Zygote gradients that are type-unstable.

Either:

  1. the example should be updated,
  2. the example should come with a warning that this usage is type-unstable, or
  3. @compact should be able to deal with this in a type-stable manner (in which case this is an actual bug rather than just a documentation issue).

```julia
julia> using Lux, Random

julia> n_in = 1;

julia> n_out = 1;

julia> nlayers = 3;

julia> model = @compact(w1=Dense(n_in, 128),
           w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
           embed = act.(w1(x))
           for w in w2
               embed = act.(w(embed))
           end
           out = w3(embed)
           @return out
       end
@compact(
    w1 = Dense(1 => 128),               # 256 parameters
    w2 = NamedTuple(
        1 = Dense(128 => 128),          # 16_512 parameters
        2 = Dense(128 => 128),          # 16_512 parameters
        3 = Dense(128 => 128),          # 16_512 parameters
    ),
    w3 = Dense(128 => 1),               # 129 parameters
    act = relu,
) do x
    embed = act.(w1(x))
    for w = w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 49_921 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> size(first(model(randn(n_in, 32), ps, st)))  # 1×32 Matrix as output.
(1, 32)

julia> using Optimisers, Zygote

julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';

julia> y_data = 2 .* x_data .- x_data .^ 3;

julia> optim = Optimisers.setup(Adam(), ps);

julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> for epoch in 1:1000
           loss, gs = Zygote.withgradient(
               ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
           Optimisers.update!(optim, ps, gs[1])
       end;

julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> loss_initial > loss_final
true
```
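
For comparison, option 1 (updating the example) could look roughly like the sketch below. It follows the `Chain([...])` alternative quoted at the top of this issue, so the hidden layers are applied in a single call instead of a loop in the `@compact` body. It assumes Lux's `Chain` accepts a vector of layers (the form used above) and folds `relu` into each hidden `Dense` layer; it is an illustration, not the documented example.

```julia
using Lux, Random

n_in, n_out, nlayers = 1, 1, 3

# Option 1 sketch: the hidden layers live in a Chain rather than a plain Vector,
# so the @compact body applies them in one call instead of looping over them.
model = @compact(w1 = Dense(n_in => 128),
    w2 = Chain([Dense(128 => 128, relu) for _ in 1:nlayers]),
    w3 = Dense(128 => n_out), act = relu) do x
    embed = act.(w1(x))
    embed = w2(embed)  # relu is applied inside each hidden Dense layer
    @return w3(embed)
end

ps, st = Lux.setup(Xoshiro(0), model)
y, _ = model(randn(Xoshiro(1), Float32, n_in, 32), ps, st)
```

The parameter count is the same as in the original example (49_921), since the hidden `Dense(128 => 128)` layers are unchanged.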
@avik-pal
Member

avik-pal commented Oct 7, 2024

It is Zygote that is inherently type-unstable. We put a lot of effort into making some of the containers type-stable under Zygote, but no guarantees are given in that regard.
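
One way to observe this is `Test.@inferred`, which throws an `ErrorException` when the concrete return type of a call cannot be inferred. The sketch below uses a small hypothetical `Chain` model and hypothetical helpers `loss` and `loss_gradient`; it checks the forward pass directly, and only records (without asserting) whether the Zygote gradient infers, since that outcome depends on the model and the Zygote version.

```julia
using Lux, Random, Test, Zygote

# Hypothetical small model; the same checks apply to any Lux layer.
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Xoshiro(1), Float32, 1, 8)

# Pass everything as arguments so inference never sees untyped globals.
loss(model, ps, st, x) = sum(abs2, first(model(x, ps, st)))
loss_gradient(model, ps, st, x) = only(Zygote.gradient(p -> loss(model, p, st, x), ps))

# Forward pass: @inferred throws if the return type is not inferred concretely.
l = @inferred loss(model, ps, st, x)

# Gradient: record whether inference succeeds instead of asserting it.
gradient_inferred = try
    @inferred loss_gradient(model, ps, st, x)
    true
catch err
    err isa ErrorException || rethrow()
    false
end
```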

> the example should be updated

That example is meant to show that you can pass in a list or tuple of layers. The forward pass will be type-stable, and so will the gradients if they are computed with something like Enzyme. Any example we give that uses a loop will be type-unstable under Zygote.
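
As a sketch of the Enzyme route, the loop-based `@compact` model from the example can be differentiated with reverse-mode `Enzyme.autodiff`, accumulating gradients into a zeroed shadow copy of the parameters. This assumes a recent Enzyme.jl release that handles this Lux model directly; the helper `loss` is hypothetical, and `x_data` is built as a 1×41 `Matrix` rather than an adjoint vector.

```julia
using Lux, Random, Enzyme

n_in, n_out, nlayers = 1, 1, 3

# The loop-based @compact model from the documentation example.
model = @compact(w1 = Dense(n_in, 128),
    w2 = [Dense(128, 128) for i in 1:nlayers], w3 = Dense(128, n_out), act = relu) do x
    embed = act.(w1(x))
    for w in w2
        embed = act.(w(embed))
    end
    @return w3(embed)
end

ps, st = Lux.setup(Xoshiro(0), model)
x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3

# Scalar loss; model, state and data are passed explicitly instead of captured.
loss(model, ps, st, x, y) = sum(abs2, first(model(x, ps, st)) .- y)

# Reverse mode: the gradient w.r.t. `ps` is accumulated into `dps`.
dps = Enzyme.make_zero(ps)
Enzyme.autodiff(Reverse, loss, Active, Const(model), Duplicated(ps, dps),
    Const(st), Const(x_data), Const(y_data))
```

After the call, `dps` mirrors the structure of `ps` and can be passed to `Optimisers.update!` in place of `gs[1]` from the Zygote loop.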

> @compact should be able to deal with this in a type-stable manner (evidently, this point is an actual bug more than just a documentation issue)

Not really. Proving that the outputs are invariant to any kind of transformation on the inputs is extremely challenging, and not possible at the macro level.

@jesseylin
Author

I was not aware of this limitation in Zygote. Still, the solution would then be either to change the documentation to use Enzyme instead, or to mark this specific usage in the example as type-unstable.
