
recurrent example for docs #144

Open
ExpandingMan opened this issue Feb 24, 2023 · 7 comments

@ExpandingMan
Contributor

Motivation and description

Dealing with recurrent networks raises a lot of questions because they work rather differently from the stateless case.

I think it would be extremely helpful to have explicit examples: one for sequence-to-sequence and one for sequence-to-one.

Possible Implementation

I might come back and contribute this, but as of posting this I still don't think I'm doing it the intended way...

@lorenzoh
Member

I haven’t used the library for recurrent nets, so would be interested to see how this works and am open to changes of API if necessary 👍

@ExpandingMan
Contributor Author

I've messed around with it more since writing this. Recurrent nets seem to require a fair amount of dedicated code, I'm not sure if FluxTraining.jl would be the place for all of it. In particular, I've found myself needing to write functions to:

  • Predict sequence-to-sequence.
  • Predict sequence-to-one.
  • Recurrently predict future of a sequence after a seed.
  • Each of these for batches.

Additionally, I wonder whether the way sequences are stored is uniform across the ecosystem. The Flux documentation itself strongly suggests that sequences should be nested arrays rather than rank-3 arrays.

@darsnack
Member

darsnack commented Feb 28, 2023

I have used Flux + FluxTraining quite a bit for recurrent models in the past. In general, you shouldn't need to do anything special. Most of the work is related primarily to Flux and how it expects the data. Here is the situation I almost always end up in, and it might be useful for you.

  1. There is a function generate(T) that creates a D x T matrix, where T is time and D is the feature dimension.
  2. I can generate a vector of samples as samples = [generate(T) for _ in 1:nsamples].
  3. I use Flux.batchseq(samples) to turn this into a sequence of batches (from a batch of sequences). This is the key step.
  4. You can generate many batches by repeating the above steps with nsamples = batch_size for nbatches iterations.

In general, as long as you think of a single sample in your dataset as a single sequence, you can adapt the steps above to get your data into the sequence of batches that Flux wants.
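
Here is a minimal, self-contained sketch of those steps; the dimensions and the toy generate are made up for illustration. Note that batchseq expects each sample as a vector of per-timestep entries, so the D x T matrices are split into column vectors first.

# toy end-to-end version of the steps above; D, T, nsamples, and generate
# are illustrative placeholders
using Flux

D, T, nsamples = 4, 10, 32
generate(T) = randn(Float32, D, T)              # step 1: D x T matrix
samples = [generate(T) for _ in 1:nsamples]     # step 2: a batch of sequences

# step 3: split each D x T matrix into its columns and batch by timestep;
# the pad value only matters if the sequences have unequal lengths
seqs = [[collect(col) for col in eachcol(s)] for s in samples]
xs = Flux.batchseq(seqs, zeros(Float32, D))     # T-element vector of D x nsamples batches

m = Chain(RNN(D => 8), Dense(8 => 1))
Flux.reset!(m)
ŷs = [m(x) for x in xs]                         # one prediction batch per timestep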

From there, achieving the different tasks is all in the loss function.

using Statistics: mean  # for averaging the loss over timesteps

# seq to seq prediction
function seq2seq_loss(loss_fn)
    function _loss(m, xs, ys)
        yhats = [m(xi) for xi in xs]
        return mean(loss_fn(yhat, yi) for (yhat, yi) in zip(yhats, ys))
    end

    return _loss
end

# seq to one prediction
function seq2one_loss(loss_fn)
    function _loss(m, xs, ys)
        yhats = [m(xi) for xi in xs]
        return loss_fn(yhats[end], ys[end])
    end

    return _loss
end

# samplers for mapping the previous output token to the next input token
# (Categorical comes from Distributions.jl; softmax is exported by Flux)
# used below in sample_model
sample_softmax(y::AbstractVector) =
    Flux.onehot(rand(Categorical(softmax(y))), 1:length(y))
function sample_softmax(y::AbstractMatrix)
    ŷs = [rand(Categorical(collect(p))) for p in eachcol(softmax(y))]

    return Flux.onehotbatch(ŷs, 1:size(y, 1))
end

sample_best(y::AbstractVector) = Flux.onehot(argmax(y), 1:length(y))
sample_best(ys::AbstractMatrix) =
    Flux.onehotbatch(map(i -> i[1], vec(argmax(ys; dims = 1))), 1:size(ys, 1))

# recurrently predict a sequence given a primer input sequence
# (assumes a non-empty primer, so that `last(tokens)` is defined)
function sample_model(model, nseq, primer = [], sampler = identity)
    Flux.reset!(model)
    tokens = [model(x) for x in primer]
    ncurrent = length(tokens)
    while ncurrent < nseq
        nexttoken = model(sampler(last(tokens)))
        push!(tokens, nexttoken)
        ncurrent += 1
    end

    return tokens
end

Note that batching does not affect any of the functions above. As long as you get the "sequence of batches" format right, you should be good.
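
As a concrete (made-up) usage example tying these pieces together, assuming the loss functions above are defined:

# toy usage of the losses above; dimensions and random data are placeholders
using Flux

D, T, nbatch = 4, 10, 32
xs = [randn(Float32, D, nbatch) for _ in 1:T]   # sequence of input batches
ys = [randn(Float32, D, nbatch) for _ in 1:T]   # matching target batches

m = Chain(RNN(D => 16), Dense(16 => D))
Flux.reset!(m)                                  # reset hidden state before each sequence
loss = seq2seq_loss(Flux.mse)
loss(m, xs, ys)                                 # scalar loss averaged over timesteps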

If you still want to express all this using FluxTraining, then the following is something I've used in the past.

get_inout_seq(xs::AbstractVector) = xs[1:(end - 1)], xs[2:end]
get_inout_seq(xs::NTuple{2}) = xs[1], xs[2]

struct BPTTTrainingPhase <: AbstractTrainingPhase end

function FluxTraining.step!(learner, phase::BPTTTrainingPhase, batch)
    xs, ys = get_inout_seq(batch)
    FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        Flux.reset!(learner.model)
        state.grads = gradient(learner.params) do
            state.ŷs = [learner.model(xi) for xi in state.xs]
            handle(FluxTraining.LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)

            handle(FluxTraining.BackwardBegin())
            return state.loss
        end
        handle(FluxTraining.BackwardEnd())
        Flux.update!(learner.optimizer, learner.params, state.grads)
    end
end

struct BPTTValidationPhase <: AbstractValidationPhase
    nfeedback::Int
    sampler
end
BPTTValidationPhase() = BPTTValidationPhase(0, identity)
BPTTValidationPhase(nfeedback) = BPTTValidationPhase(nfeedback, identity)

function FluxTraining.step!(learner, phase::BPTTValidationPhase, batch)
    xs, ys = get_inout_seq(batch)
    FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do _, state
        Flux.reset!(learner.model)
        n = length(state.xs) - phase.nfeedback
        # n steps where input drives model
        state.ŷs = [learner.model(state.xs[i]) for i in 1:n]
        # nfeedback steps where the model drives itself
        for _ in (n + 1):length(state.xs)
            ŷ = phase.sampler(state.ŷs[end])
            push!(state.ŷs, learner.model(ŷ))
        end
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end

I don't need to do this for training recurrent models, but I found it nice for a particular project where BPTT was the thing I was comparing against. Specifically, BPTTValidationPhase is nice because it allows evaluating models in a recurrently driven mode where they are fed their own output as input.
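
As a hypothetical usage sketch (assuming a learner has already been constructed whose loss function accepts whole (ŷs, ys) sequences, and that traindata/valdata iterate over batches in the sequence-of-batches format):

# hypothetical driver loop; learner, traindata, valdata, and nepochs are
# assumed to exist already
for _ in 1:nepochs
    FluxTraining.epoch!(learner, BPTTTrainingPhase(), traindata)
    # validate with 5 self-driven feedback steps, sampling inputs via sample_softmax
    FluxTraining.epoch!(learner, BPTTValidationPhase(5, sample_softmax), valdata)
end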

@darsnack
Member

darsnack commented Feb 28, 2023

If your data is already in a big rank-3 array, then you can arrange your axes as feature x samples x time and use Base.Iterators or MLUtils.jl to partition it along the second axis into a vector of feature x batch x time chunks. A Recur model in Flux should consume these chunks correctly.
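
For instance, a minimal sketch (the data array and batch size are made up; MLUtils.chunk could replace the Base.Iterators partitioning):

# partition a feature x samples x time array into feature x batch x time chunks;
# data and batchsize are illustrative placeholders
using Flux

data = randn(Float32, 4, 128, 10)   # feature x samples x time
batchsize = 32

batches = [data[:, idx, :] for idx in Iterators.partition(axes(data, 2), batchsize)]

m = Flux.RNN(4 => 8)
Flux.reset!(m)
out = m(batches[1])                 # Recur runs over the time axis of the 3D chunk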

Otherwise, I find the approach of treating each sample as a self-contained time series is the most intuitive and compatible with existing data wrangling/loading packages like MLUtils.jl. Just remember to batchseq before passing to the Flux model.

@ToucheSir
Member

Additionally, I wonder whether the way sequences are stored is uniform across the ecosystem. The Flux documentation itself strongly suggests that sequences should be nested arrays rather than rank-3 arrays.

Note that we actually do support 3D arrays of shape (features, batch, timesteps) as inputs to RNN layers. The reason it's not documented/advertised is that we're not sure whether the API makes sense. For example, how do you differentiate between a batched sequence input to a normal RNN and one timestep of input to a conv-based RNN? The current implementation also does the same partitioning by timesteps internally that you'd otherwise do by hand, so it should be slower than Kyle's suggestion above.

@darsnack
Member

darsnack commented Mar 1, 2023

Note I edited my comments from the original to correct a mistake in the order of the axis dimensions. Clearly, the time I've been spending with Jax recently is leaking...

@jie-huangfu

I'm trying to understand how Zygote does gradient accumulation in the case of an RNN. In the following, I compare its result with a manual gradient accumulation, and the results differ. What could be the reason? The code is self-contained and runnable.

using Flux 
using Random
Random.seed!(149)

# x in format (feature, samples, timesteps)
x = reshape([0.84147096, 0.9092974, 0.14112], 1, 1, 3)
y = -0.7568025

layer1 = Flux.Recur(Flux.RNNCell(1 => 5, tanh))
layer2 = Flux.Dense( 5 => 1 )
model = Flux.Chain(layer1, layer2)

Flux.reset!(model)
e, g = Flux.withgradient(model, x, y) do m, xi, yi
    yhat = [m(xi[:,:,i]) for i in 1:3]    # timesteps = 3
    return Flux.mse(yhat[3], yi)
end
println("flux gradient dWx: ", g[1][1][1].cell.Wi)

#-------- get individual gradients at each step -----------------
c1 = deepcopy(layer1.cell)
c2 = deepcopy(c1)
c3 = deepcopy(c2)

h0 = zeros(5, 1)  # initial state zero 
e3, f = Flux.withgradient(c1.Wi, c2.Wi, c3.Wi, 
    c1.Wh, c2.Wh, c3.Wh, 
    c1.b, c2.b, c3.b) do Wi1, Wi2, Wi3,   Wh1, Wh2, Wh3,  b1, b2, b3

    h1 = tanh.( Wi1 * x[:,:,1] + Wh1 * h0 + b1);  y1 = layer2(h1)
    h2 = tanh.( Wi2 * x[:,:,2] + Wh2 * h1 + b2);  y2 = layer2(h2)
    h3 = tanh.( Wi3 * x[:,:,3] + Wh3 * h2 + b3);  y3 = layer2(h3)

    Flux.mse(y3, y)
end
println("accumulated dWx:   ", f[1]+f[2]+f[3])
