The following example from the API documentation (copy-pasted for reference), which uses w2 = [Dense(128, 128) for i in 1:nlayers]
as opposed to w2 = Chain([Dense(128, 128) for i in 1:nlayers])
gives Zygote gradients which are type unstable.
Either
- the example should be updated,
- the example should come with a warning that this is type-unstable, or
- @compact should be able to deal with this in a type-stable manner (evidently, this point is an actual bug more than just a documentation issue).
julia> n_in = 1;
julia> n_out = 1;
julia> nlayers = 3;
julia> model = @compact(w1=Dense(n_in, 128),
           w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
           embed = act.(w1(x))
           for w in w2
               embed = act.(w(embed))
           end
           out = w3(embed)
           @return out
       end
@compact(
    w1 = Dense(1 => 128),               # 256 parameters
    w2 = NamedTuple(
        1 = Dense(128 => 128),          # 16_512 parameters
        2 = Dense(128 => 128),          # 16_512 parameters
        3 = Dense(128 => 128),          # 16_512 parameters
    ),
    w3 = Dense(128 => 1),               # 129 parameters
    act = relu,
) do x
    embed = act.(w1(x))
    for w = w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 49_921 parameters,
          #        plus 1 states.
julia> ps, st = Lux.setup(Xoshiro(0), model);
julia> size(first(model(randn(n_in, 32), ps, st))) # 1×32 Matrix as output.
(1, 32)
julia> using Optimisers, Zygote
julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';
julia> y_data = 2 .* x_data .- x_data .^ 3;
julia> optim = Optimisers.setup(Adam(), ps);
julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);
julia> for epoch in 1:1000
           loss, gs = Zygote.withgradient(
               ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
           Optimisers.update!(optim, ps, gs[1])
       end;
julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);
julia> loss_initial > loss_final
true
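For comparison, the Chain-based variant mentioned at the top of this issue could look roughly like the sketch below. It is an assumption about how that variant would be written, not code taken from the documentation: because a Chain is applied as a single layer rather than iterated over, the activation is folded into each hidden Dense layer here, and the name model_chain is made up for this sketch.

```julia
using Lux, Random

n_in, n_out, nlayers = 1, 1, 3

# Assumption: `relu` is moved into each hidden Dense layer, since the Chain is
# called once instead of being looped over inside the `do` block.
model_chain = @compact(w1=Dense(n_in => 128),
    w2=Chain([Dense(128 => 128, relu) for i in 1:nlayers]...),
    w3=Dense(128 => n_out), act=relu) do x
    embed = act.(w1(x))
    embed = w2(embed)      # one call replaces the `for w in w2` loop
    out = w3(embed)
    @return out
end

ps, st = Lux.setup(Xoshiro(0), model_chain)
y, _ = model_chain(randn(Float32, n_in, 32), ps, st)
size(y)
```

This only shows that the variant builds and runs; whether its Zygote gradient is inferred would still need to be checked separately.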
It is Zygote that is inherently type-unstable. We make a lot of effort to ensure some of the containers are type-stable for Zygote, but no guarantees whatsoever in that regard are given.
> the example should be updated

That example is there to show that you can pass in a list or tuple of layers. The forward pass will be type-stable, and so will the gradients if they are computed using something like Enzyme. Any example we give that uses a loop will be type-unstable for Zygote.
> @compact should be able to deal with this in a type-stable manner (evidently, this point is an actual bug more than just a documentation issue)

Not really. Proving that the outputs are invariant to any kind of transformation on the inputs is extremely challenging, and not possible at the macro level.
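One way to check the claims above locally (forward pass inferred, loop-based Zygote gradient not) is a small inference test with `Test.@inferred`. The sketch below is a hedged illustration: the helper names `loss` and `zygote_grad` are made up for this example, and no particular outcome is asserted, since the result may depend on the Lux, Zygote, and Julia versions in use.

```julia
using Lux, Random, Test, Zygote

# Same loop-based model as in the example under discussion.
model = @compact(w1=Dense(1 => 128), w2=[Dense(128 => 128) for i in 1:3],
    w3=Dense(128 => 1), act=relu) do x
    embed = act.(w1(x))
    for w in w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    @return out
end

ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Float32, 1, 32)

# Hypothetical helpers, named for this sketch only.
loss(m, x, ps, st) = sum(abs2, first(m(x, ps, st)))
zygote_grad(m, x, ps, st) = only(Zygote.gradient(p -> loss(m, x, p, st), ps))

# Run both once first, so that a genuine runtime error surfaces here and is
# not mistaken for an inference failure by the `catch` branches below.
loss(model, x, ps, st)
gs = zygote_grad(model, x, ps, st)

# `@inferred` throws when the inferred return type does not match the actual one.
forward_inferred = try
    @inferred loss(model, x, ps, st)
    true
catch
    false
end

grad_inferred = try
    @inferred zygote_grad(model, x, ps, st)
    true
catch
    false
end

@show forward_inferred grad_inferred
```

If the statements above hold for the installed versions, `forward_inferred` would be expected to be `true` and `grad_inferred` to be `false`.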
I was not aware of this limitation in Zygote. Evidently, the solution is still to either change the documentation to use Enzyme instead or mark this specific usage in the example as type-unstable.