
Backprop through time #648

Closed
MikeInnes opened this issue Feb 27, 2019 · 8 comments

Comments

@MikeInnes
Member

Continuing our series "cool things we can't have yet", and inspired by this comment I was thinking about how we'll expose BPTT. Currently, given a forward pass like this:

loss = 0
for word in seq
  loss += model(word)
end
loss

If we don't want to backprop over the whole sequence at once (gradient outside the loop) or over only a single step at a time (gradient inside the loop) then we need to split the loop as follows:

for chunk in Iterators.partition(seq, n)   # n = truncation length
  θ̄ = gradient() do
    loss = 0
    for word in chunk
      loss += model(word)
    end
    loss
  end
end

An alternative to this is to just expose primitives that let us fiddle with time steps directly. Consider:

record = Capacitor(5)
for x in seq
  θ̄ = gradient() do
    record() do
      model(x)
    end
  end
end

Alright, bear with me here. This is written as if we were backpropagating across only a single time step at a time, but with the model evaluation wrapped in record. The idea is that record will log the 5 previous backpropagators for the closure it is passed, and then chain these together for the backward pass, which means we can actually backpropagate through n previous iterations of the loop -- i.e. backpropagation through time.

What's cool about this is that it makes BPTT completely orthogonal to the structure of the forward pass. The recorder can equally well be set up to backprop through the last n steps on every iteration (sliding-window BPTT) or only on every nth iteration (ordinary truncated BPTT), or anything in between, and this can be configured differently for different parts of the model. It also isn't specific to any particular RNN implementation; e.g. this will work even though we have to backprop through h across loop iterations:

record = Capacitor(5)
h = ...
for word in seq
  θ̄ = gradient() do
    record() do
      y, h = model(word, h)
      loss(word, y)
    end
  end
end
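
To make the chaining idea concrete, here's a rough sketch of one way a recorder could buffer and chain per-step pullbacks. Nothing like Capacitor exists in Flux today; the names Recorder and backward, the step signature h -> (loss, h′), and the use of Zygote.pullback are all illustrative, and parameter gradients are omitted (only the carried state's cotangent is threaded):

using Zygote

# Hypothetical recorder: keep the last n per-step backpropagators and chain them.
struct Recorder
  n::Int
  backs::Vector{Any}   # buffered pullbacks, oldest first
end
Recorder(n::Int) = Recorder(n, Any[])

# Forward: run one step f : h -> (loss, h′) and remember its pullback.
function (r::Recorder)(f, h)
  (loss, h′), back = Zygote.pullback(f, h)
  push!(r.backs, back)
  length(r.backs) > r.n && popfirst!(r.backs)   # keep only the last n steps
  return loss, h′
end

# Backward: seed the newest step's loss with 1 and thread the carried state's
# cotangent back through the older buffered steps, i.e. truncated BPTT over n steps.
function backward(r::Recorder)
  h̄ = nothing
  for (i, back) in enumerate(reverse(r.backs))
    (h̄,) = back((i == 1 ? 1.0 : 0.0, h̄))
  end
  return h̄   # cotangent of the state entering the oldest buffered step
end

The point is that the truncation window lives entirely in the recorder, not in the shape of the forward loop.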

The main question is whether this is actually going to be intuitive for people (who aren't travelling at 88mph). If it looks weird right now, I think that's partly because we're not used to using gradient this way, so getting used to that will make the extra feature easier to reason about. At least for sliding windows, I think it's strictly better than flow-based alternatives.

@philtomson

What about the train_step! proposed at #607 (comment)?

I'm trying to translate some PyTorch code that does BPTT, and in the code I'm translating (ENAS-Pytorch) they seem to have put an explicit loop inside their forward function to do the time steps (35 time steps in this case).

I suppose that's possible with Flux's train! function: you could put the train! call inside a loop that counts up to the number of time steps you want for BPTT and give it a batch of data in each iteration - would that work?

@MikeInnes
Member Author

Yeah, right now the Flux approach to this is essentially the same as PyTorch's. step! and train! should basically be orthogonal issues; if you want to mentally rewrite the examples above, you should be able to replace gradient with step!, or the outer loop plus gradient with train!.
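
For what it's worth, the truncated-BPTT-via-train! pattern described above might look roughly like the sketch below (not code from this thread: xs, ys, model, opt and bptt_len are assumed to exist, and it uses the implicit-params train!):

using Flux

bptt_len = 35                                        # truncation window, as in the PyTorch code
chunks = Iterators.partition(zip(xs, ys), bptt_len)
data = [(chunk,) for chunk in chunks]                # one training step per chunk

# Summing the per-step losses makes one chunk's gradient accumulate
# contributions from all bptt_len steps in it.
chunkloss(chunk) = sum(Flux.mse(model(x), y) for (x, y) in chunk)

# A stateful Flux model carries its hidden state from one chunk to the next, but
# each chunk's loss is differentiated on its own, so gradients are cut at chunk
# boundaries -- i.e. truncated BPTT. Call Flux.reset!(model) between sequences.
Flux.train!(chunkloss, Flux.params(model), data, opt)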

@jie-huangfu

Hi, where's the code implementing BPTT? (In recurrent.jl?) I wrote some code doing BPTT in plain Julia and would like to compare the intermediate results with Flux. The forward pass is easy to understand, but I didn't find where the backward pass is implemented.

Could someone point me to where to check? Thanks. (A Google search landed me on this page, which seems like the best place to ask.)

@darsnack
Member

There is no code in Flux.jl that directly implements BPTT. It is "just" calling gradient over the whole loop, like Mike did in his first code snippet.
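
Concretely, with the implicit-parameters API this looks something like the sketch below (not Flux source code; model and seq are assumed to be defined, and model(word) is assumed to return the per-step loss, as in Mike's snippet):

using Flux   # `gradient` is Zygote's, re-exported by Flux

θ = Flux.params(model)
θ̄ = gradient(θ) do
  loss = 0
  for word in seq
    loss += model(word)   # Zygote traces every iteration of the loop...
  end
  loss                    # ...and the backward pass runs through all of them
end

Here θ̄ is a Zygote.Grads, and θ̄[p] holds the gradient accumulated over all time steps for the parameter p.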

@jie-huangfu

jie-huangfu commented Mar 11, 2023

does the "gradient" function accumulate gradients through time somehow? in case of seq2one, how does the function know how to loop? ideally, I would like to be able to intercept and verify the accumulation process.

For comparison, a manually written BPTT has something like the following (from Karpathy's code). Are there variables similar to "dWhh" kept somewhere?

for t in reversed(xrange(len(inputs))):
  dy = np.copy(ps[t])
  dy[targets[t]] -= 1
  dWhy += np.dot(dy, hs[t].T)
  dby += dy
  dh = np.dot(Why.T, dy) + dhnext # backprop into h
  dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
  dbh += dhraw
  dWxh += np.dot(dhraw, xs[t].T)
  dWhh += np.dot(dhraw, hs[t-1].T)
  dhnext = np.dot(Whh.T, dhraw)

@darsnack
Member

darsnack commented Mar 12, 2023

Yes, the AD backend, Zygote, will handle gradient accumulation through the loop. See this comment for how you can implement many-to-many or many-to-one models. Also check the recurrent docs.
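
If you want to inspect the accumulation yourself, one option is to write the same tiny tanh RNN in plain Julia and differentiate the whole loop. The sketch below is illustrative only (sizes and data are made up, biases are dropped, and it uses a squared-error loss rather than Karpathy's softmax), but the resulting dWxh/dWhh/dWhy can be compared entry by entry with a hand-written backward pass:

using Zygote

Wxh, Whh, Why = randn(4, 3), randn(4, 4), randn(2, 4)
xs = [randn(3) for _ in 1:5]   # 5 time steps of input
ys = [randn(2) for _ in 1:5]   # 5 targets

function seqloss(Wxh, Whh, Why)
  h = zeros(4)
  l = 0.0
  for (x, y) in zip(xs, ys)            # same recurrence as the numpy code
    h = tanh.(Wxh * x .+ Whh * h)
    l += sum(abs2, Why * h .- y)       # squared error at every step
  end
  return l
end

# Zygote differentiates through the whole loop, accumulating the through-time
# contributions into single dWxh, dWhh, dWhy arrays.
dWxh, dWhh, dWhy = Zygote.gradient(seqloss, Wxh, Whh, Why)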

@jie-huangfu

Thanks for your reply, darsnack. I posted another question in issue 144, as that one seems more recent. Do you mind taking a look?


@CarloLucibello
Member

closing as old
