recurrent example for docs #144
Comments
I haven’t used the library for recurrent nets, so I’d be interested to see how this works, and I’m open to API changes if necessary 👍
I've messed around with it more since writing this. Recurrent nets seem to require a fair amount of dedicated code, and I'm not sure FluxTraining.jl would be the place for all of it. In particular, I've found myself needing to write functions to:
Additionally, I wonder whether the way sequences are stored is uniform across the ecosystem. The Flux documentation itself strongly suggests that sequences should be nested arrays rather than rank-3 arrays; see the small sketch of the two layouts below.
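For concreteness, here is a minimal sketch (not from the original thread; the sizes are made up) of converting between the two layouts being compared: a single rank-3 array of shape (features, batch, time) versus a vector of per-timestep (features, batch) matrices, the "sequence of batches" format discussed further down.

features, batchsize, timesteps = 4, 8, 10

# rank-3 layout: one array of size (features, batch, time)
x3d = rand(Float32, features, batchsize, timesteps)

# nested layout: a vector with one (features, batch) matrix per timestep
xs = [x3d[:, :, t] for t in 1:timesteps]

# and back again, concatenating the slices along a third axis
x3d_again = cat(xs...; dims = 3)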
I have used Flux + FluxTraining quite a bit for recurrent models in the past. In general, you shouldn't need to do anything special. Most of the work is related primarily to Flux and how it expects the data. Here is the situation I almost always end up in, and it might be useful for you.
In general, as long as you think of a single sample in your dataset as a single sequence, you can adapt the steps above to get your data into the sequence of batches (samples) that Flux wants. From there, achieving the different tasks is all in the loss function.

using Flux
using Statistics: mean            # used for averaging per-timestep losses
using Distributions: Categorical  # used by the samplers further down

# seq to seq prediction
function seq2seq_loss(loss_fn)
function _loss(m, xs, ys)
yhats = [m(xi) for xi in xs]
return mean(loss_fn(yhat, yi) for (yhat, yi) in zip(yhats, ys))
end
return _loss
end
# seq to one prediction
function seq2one_loss(loss_fn)
function _loss(m, xs, ys)
yhats = [m(xi) for xi in xs]
return loss_fn(yhats[end], ys[end])
end
return _loss
end
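# --- hypothetical usage sketch (not from the original comment) ---
# xs_demo / ys_demo are in the "sequence of batches" format: a vector of
# (features, batch) matrices, one entry per timestep
toymodel = Flux.Chain(Flux.RNN(4 => 8), Flux.Dense(8 => 4))
xs_demo = [rand(Float32, 4, 16) for _ in 1:10]
ys_demo = [rand(Float32, 4, 16) for _ in 1:10]
Flux.reset!(toymodel)
demo_loss = seq2seq_loss(Flux.mse)(toymodel, xs_demo, ys_demo)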
# samplers for mapping the previous token to the next token
# used below in sample_model
sample_softmax(y::AbstractVector) =
Flux.onehot(rand(Categorical(softmax(y))), 1:length(y))
function sample_softmax(y::AbstractMatrix)
ŷs = [rand(Categorical(p)) for p in eachcol(softmax(y))]
return Flux.onehotbatch(ŷs, 1:size(y, 1))
end
sample_best(y::AbstractVector) = Flux.onehot(argmax(y), 1:length(y))
sample_best(y::AbstractMatrix) = Flux.onehotbatch(vec(map(i -> i[1], argmax(y; dims = 1))), 1:size(y, 1))
# recurrently predict a sequence given a primer input sequence
function sample_model(model, nseq, primer = [], sampler = identity)
Flux.reset!(model)
tokens = [model(x) for x in primer]
ncurrent = length(tokens)
while ncurrent < nseq
nexttoken = model(sampler(last(tokens)))
push!(tokens, nexttoken)
ncurrent += 1
end
return tokens
end

Note that batching does not affect any of the functions above. As long as you get the "sequence of batches" format right, you should be good. If you still want to express all this using FluxTraining, then the following is something I've used in the past.

get_inout_seq(xs::AbstractVector) = xs[1:(end - 1)], xs[2:end]
get_inout_seq(xs::NTuple{2}) = xs[1], xs[2]
struct BPTTTrainingPhase <: AbstractTrainingPhase end
function FluxTraining.step!(learner, phase::BPTTTrainingPhase, batch)
xs, ys = get_inout_seq(batch)
FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
Flux.reset!(learner.model)
state.grads = gradient(learner.params) do
state.ŷs = [learner.model(xi) for xi in state.xs]
handle(FluxTraining.LossBegin())
state.loss = learner.lossfn(state.ŷs, state.ys)
handle(FluxTraining.BackwardBegin())
return state.loss
end
handle(FluxTraining.BackwardEnd())
Flux.update!(learner.optimizer, learner.params, state.grads)
end
end
struct BPTTValidationPhase <: AbstractValidationPhase
nfeedback::Int
sampler
end
BPTTValidationPhase() = BPTTValidationPhase(0, identity)
BPTTValidationPhase(nfeedback) = BPTTValidationPhase(nfeedback, identity)
function FluxTraining.step!(learner, phase::BPTTValidationPhase, batch)
xs, ys = get_inout_seq(batch)
FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do _, state
Flux.reset!(learner.model)
n = length(state.xs) - phase.nfeedback
# n steps where input drives model
state.ŷs = [learner.model(state.xs[i]) for i in 1:n]
# nfeedback steps where the model drives itself
for _ in (n + 1):length(state.xs)
ŷ = phase.sampler(state.ŷs[end])
push!(state.ŷs, learner.model(ŷ))
end
state.loss = learner.lossfn(state.ŷs, state.ys)
end
end

I don't need to do this for training recurrent models, but I found it nice for a particular project where BPTT was the thing I was comparing against.
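Driving these phases could then look roughly like the sketch below. This is not from the original comment: model, traindata, and valdata are placeholders, the exact Learner constructor and optimizer spelling depend on the FluxTraining.jl and Flux versions, and the loss function here takes the already-computed predictions rather than the model, matching how the phases above call learner.lossfn.

# hypothetical sketch: model, traindata, valdata are placeholders;
# traindata/valdata iterate (input sequence, target sequence) batches
seq_lossfn(ŷs, ys) = mean(Flux.mse(ŷ, y) for (ŷ, y) in zip(ŷs, ys))
learner = Learner(model, seq_lossfn; optimizer = Flux.ADAM())  # spelled Adam in newer Flux
FluxTraining.epoch!(learner, BPTTTrainingPhase(), traindata)
FluxTraining.epoch!(learner, BPTTValidationPhase(5, sample_softmax), valdata)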
If your data is already in a big rank-3 array, then you can make your axis order features × batch × time. Otherwise, I find the approach of treating each sample as a self-contained time series the most intuitive and the most compatible with existing data wrangling/loading packages like MLUtils.jl. Just remember to call Flux.reset! before each new sequence.
Note that we actually do support 3D arrays of shape (features, batch, time).
Note I edited my comments from the original to correct a mistake in the order of the axis dimensions. Clearly, the time I've been spending with JAX recently is leaking...
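As a small illustration of that 3D-array path (a sketch, not from the thread; it assumes a Flux version where the recurrent wrapper folds a 3D input over its last dimension), the two calling styles below should correspond:

using Flux

m = Flux.Chain(Flux.RNN(2 => 3), Flux.Dense(3 => 1))
x3d = rand(Float32, 2, 4, 5)   # (features, batch, time)

# feed the whole rank-3 array at once; the recurrent layer folds over time
Flux.reset!(m)
y3d = m(x3d)                   # size (1, 4, 5)

# or feed one (features, batch) slice per timestep yourself
Flux.reset!(m)
ys = [m(x3d[:, :, t]) for t in 1:size(x3d, 3)]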
I'm trying to understand how Zygote does the gradient accumulation in the case of an RNN. In the following I'm comparing the result with a manual gradient accumulation, and the results are different. What could be the reason here? The code is self-contained and runnable.

using Flux
using Random
Random.seed!(149)
# x in format (feature, samples, timesteps)
x = reshape([0.84147096, 0.9092974, 0.14112], 1, 1, 3)
y = -0.7568025
layer1 = Flux.Recur(Flux.RNNCell(1 => 5, tanh))
layer2 = Flux.Dense(5 => 1)
model = Flux.Chain(layer1, layer2)
Flux.reset!(model)
e, g = Flux.withgradient(model, x, y) do m, xi, yi
yhat = [m(xi[:,:,i]) for i in 1:3] # timesteps = 3
return Flux.mse(yhat[3], yi)
end
println("flux gradient dWx: ", g[1][1][1].cell.Wi)
#-------- get individual gradients at each step -----------------
c1 = deepcopy(layer1.cell)
c2 = deepcopy(c1)
c3 = deepcopy(c2)
h0 = zeros(5, 1) # initial state zero
e3, f = Flux.withgradient(c1.Wi, c2.Wi, c3.Wi,
c1.Wh, c2.Wh, c3.Wh,
c1.b, c2.b, c3.b) do Wi1, Wi2, Wi3, Wh1, Wh2, Wh3, b1, b2, b3
h1 = tanh.( Wi1 * x[:,:,1] + Wh1 * h0 + b1); y1 = layer2(h1)
h2 = tanh.( Wi2 * x[:,:,2] + Wh2 * h1 + b2); y2 = layer2(h2)
h3 = tanh.( Wi3 * x[:,:,3] + Wh3 * h2 + b3); y3 = layer2(h3)
Flux.mse(y3, y)
end
println("accumulated dWx: ", f[1]+f[2]+f[3]) |
Motivation and description
Dealing with recurrent networks raises a lot of questions because they work rather differently from the stateless case.
I think it would be extremely helpful to have explicit examples: one for sequence-to-sequence and one for sequence-to-one.
Possible Implementation
I might come back and contribute this, but as I'm posting this I still don't think I'm doing this the intended way...