On RNNs

This is part #8 of my notes on CS231n. The course is openly available, including the video lectures and assignments.

These notes are based on this lecture by Zane Durante.


Recurrent neural networks are the result of looping a neural network over a sequence, and they are easiest to understand with an unrolled RNN diagram:

(Diagram: the same RNN cell unrolled over time — at each step t the input x_t and the previous hidden state h_{t-1} go in, and the output y_t and new hidden state h_t come out.)

In order to loop a NN while preserving context we introduce the hidden state, which carries information about what has been processed so far. This extra output is what makes it possible to daisy-chain the network step after step.

In math notation:

$$(1)\quad h_t = f_W(h_{t-1}, x_t)$$
$$(2)\quad y_t = f_{W_{hy}}(h_t)$$

where:

  • $h_t$ is the new internal state at time step $t$
  • $h_{t-1}$ is the previous internal state
  • $f_W$ is some function with parameters $W$
  • $y_t$ is the output at time step $t$
  • $f_{W_{hy}}$ is some function with parameters $W_{hy}$

The important bit is that $W_{hy}$ is a learnable weight matrix distinct from $W$. But both matrices are reused at every sequence step!
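Equations (1) and (2) can be sketched in a few lines of numpy. All sizes and values here are made up for illustration; the point is that the same three matrices are applied at every step:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, input_size, output_size = 4, 3, 2   # arbitrary toy dimensions

# The same matrices are reused at every time step.
W_hh = rng.normal(size=(hidden_size, hidden_size)) * 0.1
W_xh = rng.normal(size=(hidden_size, input_size)) * 0.1
W_hy = rng.normal(size=(output_size, hidden_size)) * 0.1

def rnn_step(h_prev, x):
    """One application of f_W: mix previous state and input, squash with tanh."""
    h = np.tanh(W_hh @ h_prev + W_xh @ x)
    y = W_hy @ h          # f_{W_hy}: read the output off the hidden state
    return h, y

h = np.zeros(hidden_size)              # h_0
xs = rng.normal(size=(5, input_size))  # a length-5 input sequence
for x in xs:
    h, y = rnn_step(h, x)              # same W_hh, W_xh, W_hy at every step
```

Note that the sequence length never appears in the parameter shapes — that is what weight sharing across time buys you.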

That makes RNNs super flexible:

(Diagram: the RNN input/output patterns — many-to-many in two variants, many-to-one, and one-to-many.)
  • e.g. one-to-many: image captioning, given an image produce a sequence of words.
  • e.g. many-to-one: video action prediction, given a sequence of video frames predict the action class
  • e.g. many-to-many: video captioning, sequence of frames to sequence of words
  • e.g. many-to-many: video classification on a frame level

Because of these different variants, how you learn the weight matrices is problem specific. A many-to-many setup, for instance, could look like:

(Diagram: a many-to-many unrolled RNN — a loss L_t is computed from each output y_t, and the total loss is their sum.)

You compute a loss $L_t$ at each output $y_t$ and sum them all together to get the total loss. In backprop you then compute the gradients at each step $t$ and sum those too.
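As a toy illustration (scalar outputs, squared-error per-step losses, made-up numbers), summing the per-step losses looks like:

```python
import numpy as np

# Toy per-step outputs and targets for a length-4 sequence.
ys      = np.array([0.2, 0.7, 0.1, 0.9])   # y_1 .. y_4
targets = np.array([0.0, 1.0, 0.0, 1.0])

step_losses = (ys - targets) ** 2   # L_t: one loss per time step
total_loss  = step_losses.sum()     # L = sum_t L_t, the quantity you backprop
```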

A many-to-one could look like:

(Diagram: a many-to-one unrolled RNN — a single loss L is computed from the final output y.)

You have a single loss at the end of the sequence.

There are some nuances in how you actually compute the backprop in practice, especially for long sequences. To compute the gradients you have to keep the intermediate activations from every step so that they are available during the backward pass, which means for very long input sequences you will run out of memory:

(Diagram: backprop through the full sequence — every step's activations must be stored for the loss L.)

So to practically resolve this issue you can do truncated backprop through time: run forward and backward through chunks of the sequence instead of the whole sequence:

(Diagram: the sequence split into chunks, each with its own loss L and its own backward pass.)

In truncated backpropagation through time we run forward and backward through chunks instead of the whole sequence, while carrying the hidden state forward across chunks. Each chunk is locally treated as if it were the entire sequence: its $h_0$ is initialised from the previous chunk's final hidden state, but the gradients are not carried over.
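A minimal numpy sketch of this, using a toy loss $L_t = \sum_i h_{t,i}$ so the backward pass stays short (dimensions, loss, and learning rate are all made up). The key points are that the final hidden state is carried into the next chunk, while no gradient is ever returned for $h_0$, so gradient flow stops at the chunk boundary:

```python
import numpy as np

rng = np.random.default_rng(0)
H, D = 4, 3                           # hidden / input sizes, arbitrary
W_hh = rng.normal(size=(H, H)) * 0.1
W_xh = rng.normal(size=(H, D)) * 0.1

def chunk_backward(xs, h0):
    """Forward + backward over ONE chunk, treated as if it were the whole
    sequence. Returns the final hidden state (carried to the next chunk)
    and dL/dW_hh computed from this chunk's steps only."""
    hs, h = [h0], h0
    for x in xs:                       # forward: keep activations for backprop
        h = np.tanh(W_hh @ h + W_xh @ x)
        hs.append(h)
    dW_hh, dh_next = np.zeros_like(W_hh), np.zeros(H)
    for t in range(len(xs), 0, -1):    # backward through this chunk only
        dh = np.ones(H) + dh_next      # dL/dh_t: local loss + later steps
        draw = (1 - hs[t] ** 2) * dh   # back through tanh
        dW_hh += np.outer(draw, hs[t - 1])
        dh_next = W_hh.T @ draw
    # No gradient w.r.t. h0 is returned: the boundary cuts gradient flow.
    return hs[-1], dW_hh

seq = rng.normal(size=(12, D))
h = np.zeros(H)
for chunk in np.split(seq, 3):            # three chunks of 4 steps each
    h, dW_hh = chunk_backward(chunk, h)   # hidden state carries over...
    W_hh -= 0.01 * dW_hh                  # ...and weights update per chunk
```

In an autograd framework the same effect is usually achieved by detaching the carried hidden state at each chunk boundary.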

You might wonder how this works for a many-to-one task; it is slightly different:

(Diagram: truncated backprop in a many-to-one setup — only the last chunk receives gradient from the single loss L.)

In this setting, the final loss only backpropagates through the last chunk. The previous chunks influence the result through the carried hidden state, but gradients are detached at the chunk boundaries. This tradeoff makes training feasible for long sequences, but it limits how far back in time the model can assign credit or blame for the final prediction.


In practice most useful RNNs are in fact multilayer RNNs, where you stack hidden layers, each with its own set of weights.
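The stacking can be sketched like this (a single time step, all sizes made up; for simplicity every layer has the same width): each layer keeps its own hidden state and its own weights, and the state of layer $l$ becomes the input of layer $l+1$.

```python
import numpy as np

rng = np.random.default_rng(0)
H, num_layers = 4, 3

# One (W_hh, W_xh) pair per layer -- weights are NOT shared across layers.
layers = [(rng.normal(size=(H, H)) * 0.1,
           rng.normal(size=(H, H)) * 0.1) for _ in range(num_layers)]

hs = [np.zeros(H) for _ in range(num_layers)]   # one hidden state per layer
x = rng.normal(size=H)                          # input at this time step
for l, (W_hh, W_xh) in enumerate(layers):
    hs[l] = np.tanh(W_hh @ hs[l] + W_xh @ x)
    x = hs[l]                       # layer l's state is layer l+1's input
```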


The vanishing gradient problem

At each time step, a vanilla RNN takes the previous hidden state $h_{t-1}$ and the current input $x_t$, mixes them linearly, then passes the result through a nonlinearity:

(Diagram: one vanilla RNN cell — h_{t-1} and x_t are stacked, multiplied by W, and passed through tanh to give h_t.)
$$h_t = \tanh(W_{hh}h_{t-1} + W_{xh}x_t)$$
$$h_t = \tanh\left( \begin{pmatrix} W_{hh} & W_{xh} \end{pmatrix} \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix} \right)$$
$$h_t = \tanh\left( W \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix} \right)$$
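The single-matrix form is just a notational rewrite of the two-matrix form, which is easy to check numerically (random toy matrices):

```python
import numpy as np

rng = np.random.default_rng(0)
H, D = 4, 3
W_hh = rng.normal(size=(H, H))
W_xh = rng.normal(size=(H, D))
h_prev, x = rng.normal(size=H), rng.normal(size=D)

# Two separate matrix products...
h_a = np.tanh(W_hh @ h_prev + W_xh @ x)

# ...equal one product with the horizontally stacked matrix
# applied to the vertically stacked vector.
W = np.hstack([W_hh, W_xh])                      # shape (H, H + D)
h_b = np.tanh(W @ np.concatenate([h_prev, x]))
```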

The issue shows up during backpropagation through time. To update the parameters, the gradient must flow backward through many recurrent steps. At each step, that gradient is multiplied again by the recurrent weights and by the derivative of the activation function:

(Diagram: the backward pass through one cell — the gradient is multiplied by the transposed recurrent weights and by tanh's derivative.)

In vanilla RNNs, the activation function is usually $\tanh$, whose derivative is bounded by 1 and is often much smaller. Since the gradient is multiplied by these terms repeatedly across time, it tends to shrink and eventually vanish.
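You can see how quickly the repeated multiplication kills the gradient with a one-factor-per-step toy calculation (a single scalar pre-activation, ignoring the weight matrix entirely):

```python
import numpy as np

# d/dz tanh(z) = 1 - tanh(z)^2, which equals 1 only at z = 0
# and is strictly smaller everywhere else.
z = 1.0
factor = 1 - np.tanh(z) ** 2     # roughly 0.42 for z = 1

# Backprop through time multiplies one such factor per step:
grad_scale = factor ** 50        # the gradient scale after 50 steps
```

After just 50 steps the surviving gradient is vanishingly small — far below anything a learning-rate update could use.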

(Diagram: the multiplications accumulate as the gradient flows back through every unrolled step.)

LSTM

LSTMs (Hochreiter & Schmidhuber, 1997) were introduced to address this. The key idea is to keep a cell state ctc_t that gives information a more stable path through time, while using gates to control what information is written, erased, and exposed.

An LSTM computes four vectors from the previous hidden state and the current input:

  • $i_t$: the input gate controls how much new information is written
  • $f_t$: the forget gate controls how much of the previous cell state is kept
  • $o_t$: the output gate controls how much of the cell state is exposed as hidden state
  • $g_t$: the "gate gate" controls what is written to the cell state

(Diagram: one LSTM cell — h_{t-1} and x_t are stacked, multiplied by W, split into the gates f, i, g, o, which update c_t and produce h_t.)
$$\begin{pmatrix} i_t \\ f_t \\ o_t \\ g_t \end{pmatrix} = \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{pmatrix} W \begin{pmatrix} h_{t-1} \\ x_t \end{pmatrix}$$
$$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$
$$h_t = o_t \odot \tanh(c_t)$$

The crucial part is the cell-state update:

$$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$

Somewhat analogous to how residual connections help gradient flow in ResNets, LSTMs mitigate the vanishing gradient problem by creating a direct path with uninterrupted gradient flow:

(Diagram: the cell state c_t forms an additive path through the cell, interrupted only by the elementwise forget gate.)

That allows gradients to flow across many more time steps without vanishing:

(Diagram: the cell states c_0, c_1, c_2, … form a gradient highway across the unrolled time steps.)
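The LSTM equations above translate almost line by line into numpy (toy dimensions, random weights, no learned biases for brevity):

```python
import numpy as np

rng = np.random.default_rng(0)
H, D = 4, 3

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# One big matrix maps the stacked (h_{t-1}, x_t) vector to all four
# pre-activations (i, f, o, g) at once, as in the block equation above.
W = rng.normal(size=(4 * H, H + D)) * 0.1

def lstm_step(h_prev, c_prev, x):
    z = W @ np.concatenate([h_prev, x])
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)   # gates squashed to (0, 1)
    g = np.tanh(g)                                 # candidate values
    c = f * c_prev + i * g     # cell update: forget some old, write some new
    h = o * np.tanh(c)         # expose a gated view of the cell
    return h, c

h, c = np.zeros(H), np.zeros(H)
for x in rng.normal(size=(5, D)):  # run a length-5 toy sequence
    h, c = lstm_step(h, c, x)
```

The additive `f * c_prev + i * g` update is the gradient highway: as long as the forget gate stays near 1, the cell state (and its gradient) passes through largely untouched.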

RNNs were among the first neural architectures to work well on sequence generation tasks such as image captioning, where a pretrained CNN encodes the image into a feature vector, and that representation is used to initialize the RNN hidden state. The decoder then starts from a special start token and generates one word at a time, feeding each predicted token back into the model until it emits an end token.
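The decode loop described above can be sketched as follows. Everything here is a hypothetical stand-in: the vocabulary is a toy list, and random matrices play the roles of the pretrained CNN feature and the trained decoder weights, so the generated "caption" is meaningless — the structure of the loop is the point.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab = ["<start>", "a", "dog", "runs", "<end>"]   # toy vocabulary
H, V = 4, len(vocab)

# Stand-ins for a real captioner: the CNN feature initialises h_0,
# and the weight matrices would come from training.
cnn_feature = rng.normal(size=H)
W_hh = rng.normal(size=(H, H)) * 0.5
W_xh = rng.normal(size=(H, V)) * 0.5
W_hy = rng.normal(size=(V, H)) * 0.5

def step(h, token_id):
    x = np.eye(V)[token_id]                # one-hot embedding of last token
    h = np.tanh(W_hh @ h + W_xh @ x)
    return h, int(np.argmax(W_hy @ h))     # greedy: pick the best next word

h, token = cnn_feature, vocab.index("<start>")
caption = []
while len(caption) < 10:                   # length cap, in case <end> never comes
    h, token = step(h, token)              # feed the prediction back in
    if vocab[token] == "<end>":
        break
    caption.append(vocab[token])
```

Real systems typically replace the one-hot lookup with a learned embedding and the greedy argmax with sampling or beam search, but the feed-the-output-back-in loop is the same.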

They were also very successful in character level language modeling, where they learned to generate text one character at a time. A well-known example is Karpathy et al., 2015, which showed that recurrent networks could produce coherent character sequences and that some individual hidden units learned interpretable behaviors, such as tracking quotes, brackets, indentation and other long-range textual structure.

RNN based models were also successfully used in multimodal tasks such as Visual Question Answering, Visual Dialog, and Vision-and-Language Navigation.

To me, there is something deeply fascinating about that idea of an internal state. It is a useful analogy for how we humans move through the world, absorbing sensory information and updating an internal, temporally aware understanding of physical space. That is also why the connection to world models is so compelling to me, since they too must maintain a latent state of the world as it changes through time. But that is for another day.
