On RNNs
This is part #8 of my notes on CS231n. The course is openly available, including the video lectures and assignments.
These notes are based on this lecture by Zane Durante.
Recurrent neural networks are the result of looping over a neural network, and they are most easily understood with this unrolled RNN diagram:
In order to loop a NN while preserving context we introduce the hidden state, which carries information about what has been processed so far. This extra output is what makes it possible to daisy-chain the network with itself.
In math notation:

$$h_t = f_W(h_{t-1}, x_t)$$
$$y_t = f_{W_{hy}}(h_t)$$

where:
- $h_t$ is the new internal state at time step $t$
- $h_{t-1}$ is the previous internal state
- $f_W$ is some function with parameters $W$
- $y_t$ is the output at time step $t$
- $f_{W_{hy}}$ is some function with parameters $W_{hy}$
The important bit is that $W_{hy}$ is a learnable weight matrix different from $W$. And these two matrices are reused at every step of the sequence!
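A minimal numpy sketch of this recurrence (the sizes and random weights are illustrative assumptions, not values from the course):

```python
import numpy as np

rng = np.random.default_rng(0)
H, D = 4, 3                                # hidden size, input size (illustrative)
W_hh = rng.standard_normal((H, H)) * 0.1   # recurrent weights, reused at every step
W_xh = rng.standard_normal((H, D)) * 0.1   # input-to-hidden weights
W_hy = rng.standard_normal((2, H)) * 0.1   # hidden-to-output weights

def rnn_step(h_prev, x):
    """One recurrent step: h_t = tanh(W_hh h_{t-1} + W_xh x_t)."""
    return np.tanh(W_hh @ h_prev + W_xh @ x)

h = np.zeros(H)                            # initial hidden state
for x in rng.standard_normal((5, D)):      # a length-5 input sequence
    h = rnn_step(h, x)                     # same weights at every step
    y = W_hy @ h                           # per-step output y_t
```

Note that the loop only ever updates `h`; the weight matrices never change within a sequence, which is exactly what makes the network applicable to sequences of any length.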
That makes RNNs super flexible:
- e.g. one-to-many: image captioning, given an image produce a sequence of words.
- e.g. many-to-one: video action prediction, given a sequence of video frames predict the action class
- e.g. many-to-many: video captioning, sequence of frames to sequence of words
- e.g. many-to-many video classification on a frame level
Because of the different variants, learning the weight matrices is problem specific. For instance, in a many-to-many scenario, training could look like:
You compute a loss $L_t$ at each step and sum them all together to get the total loss $L = \sum_t L_t$. Then in backprop you compute the gradients at each step and sum them all together.
A many-to-one could look like:
You have a single loss at the end of the sequence.
There are some nuances in how you actually compute backprop in practice, especially for long sequences. To compute the gradients you have to keep the intermediate activations from every forward step so that they are available during the backward pass. For very long input sequences you will run out of memory:
To resolve this issue in practice you can do truncated backpropagation through time, which means running forward and backward through chunks of the sequence instead of the whole sequence:
In truncated backpropagation through time we run forward and backward through chunks while carrying the hidden state forward across chunks. Each chunk is locally treated as if it were the entire sequence: its initial hidden state is initialised from the previous chunk's final hidden state, but gradients are not carried over the chunk boundary.
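The chunking pattern can be sketched in numpy; the tanh RNN, the per-step loss, and all sizes below are illustrative assumptions:

```python
import numpy as np

rng = np.random.default_rng(1)
H, D, T, CHUNK = 4, 3, 12, 4              # sizes are illustrative
W_hh = rng.standard_normal((H, H)) * 0.1
W_xh = rng.standard_normal((D, H)) * 0.1
xs = rng.standard_normal((T, D))          # one long input sequence

h = np.zeros(H)                           # hidden state carried across chunks
for start in range(0, T, CHUNK):
    chunk = xs[start:start + CHUNK]
    # Forward: cache activations for this chunk only.
    hs = [h]
    for x in chunk:
        hs.append(np.tanh(hs[-1] @ W_hh + x @ W_xh))
    # Backward through this chunk only; hs[0] is treated as a constant,
    # so no gradient crosses the chunk boundary.
    dW_hh = np.zeros_like(W_hh)
    dh = np.zeros(H)
    for t in reversed(range(len(chunk))):
        dh = dh + hs[t + 1]               # toy per-step loss L_t = ||h_t||^2 / 2
        dpre = dh * (1 - hs[t + 1] ** 2)  # backprop through tanh
        dW_hh += np.outer(hs[t], dpre)
        dh = dpre @ W_hh.T                # flows to the previous step, within the chunk
    W_hh -= 0.01 * dW_hh                  # one SGD update per chunk
    h = hs[-1]                            # carry the hidden state, but not gradients
```

The key line is the last one: only the value of `h` crosses the boundary; the activation cache `hs` is discarded, which is what keeps memory bounded.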
You might wonder how this works for a many-to-one task; it works slightly differently:
In this setting, the final loss only backpropagates through the last chunk. The previous chunks influence the result through the carried hidden state, but gradients are detached at the boundaries. This tradeoff makes training feasible for long sequences, but it limits how far back in time the model can assign credit or blame for the final prediction.
In practice, most useful RNNs are in fact multilayer RNNs, where you stack hidden layers, each with its own set of weights:
The vanishing gradient problem
At each time step, a vanilla RNN takes the previous hidden state $h_{t-1}$ and the current input $x_t$, mixes them linearly, then passes the result through a nonlinearity:

$$h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t)$$
The issue shows up during backpropagation through time. To update the parameters, the gradient must flow backward through many recurrent steps. At each step, that gradient is multiplied again by the recurrent weights and by the derivative of the activation function:

$$\frac{\partial h_t}{\partial h_{t-1}} = \mathrm{diag}\!\left(\tanh'(W_{hh} h_{t-1} + W_{xh} x_t)\right) W_{hh}$$
In vanilla RNNs, the activation function is usually $\tanh$, whose derivative is bounded by 1 and is often much smaller. Since the gradient is multiplied by these terms repeatedly across time, it tends to shrink and eventually vanish.
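A small numpy experiment illustrates the effect: repeatedly multiplying by the per-step Jacobian $\mathrm{diag}(\tanh') \, W_{hh}$ shrinks the gradient geometrically. The weight scale and sizes here are chosen purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(2)
H = 8
W_hh = rng.standard_normal((H, H)) * 0.1  # small recurrent weights (illustrative)
h = rng.standard_normal(H)

grad = np.eye(H)                 # accumulated Jacobian d h_T / d h_0
norms = []
for _ in range(50):              # 50 recurrent steps
    h = np.tanh(W_hh @ h)
    J = np.diag(1 - h ** 2) @ W_hh   # one-step Jacobian: diag(tanh') @ W_hh
    grad = J @ grad              # chain rule across time
    norms.append(np.linalg.norm(grad))
```

After 50 steps the gradient norm has collapsed toward zero; with large recurrent weights the same product would instead explode, which is the mirror-image problem.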
LSTM
LSTMs (Hochreiter & Schmidhuber, 1997) were introduced to address this. The key idea is to keep a cell state that gives information a more stable path through time, while using gates to control what information is written, erased, and exposed.
An LSTM computes four vectors from the previous hidden state $h_{t-1}$ and the current input $x_t$:
- $i$: the input gate controls how much new information is written
- $f$: the forget gate controls how much of the previous cell state is kept
- $o$: the output gate controls how much of the cell state is exposed as the hidden state
- $g$: the gate gate controls what candidate values are written to the cell state
The crucial part is the cell-state update:

$$c_t = f \odot c_{t-1} + i \odot g$$
$$h_t = o \odot \tanh(c_t)$$
Somewhat analogous to how residual connections help gradient flow in ResNets, LSTMs mitigate the vanishing gradient problem by creating a direct path with uninterrupted gradient flow:
That allows gradients to flow across many more time steps without vanishing:
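A single LSTM step can be sketched in numpy as follows; the stacked weight matrix, sizes, and random initialisation are illustrative assumptions:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(3)
H, D = 4, 3                             # illustrative sizes
# One stacked matrix produces all four gate pre-activations at once.
W = rng.standard_normal((4 * H, H + D)) * 0.1

def lstm_step(h_prev, c_prev, x):
    z = W @ np.concatenate([h_prev, x])
    i = sigmoid(z[0 * H:1 * H])         # input gate: how much new info is written
    f = sigmoid(z[1 * H:2 * H])         # forget gate: how much of c_{t-1} is kept
    o = sigmoid(z[2 * H:3 * H])         # output gate: how much of c_t is exposed
    g = np.tanh(z[3 * H:4 * H])         # candidate values to write
    c = f * c_prev + i * g              # additive cell update: the gradient highway
    h = o * np.tanh(c)
    return h, c

h, c = np.zeros(H), np.zeros(H)
for x in rng.standard_normal((5, D)):   # a length-5 input sequence
    h, c = lstm_step(h, c, x)
```

Notice that `c` is only ever scaled by `f` and added to, never pushed through a squashing nonlinearity between steps; that additive path is what carries gradients across long spans of time.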
RNNs were among the first neural architectures to work well on sequence generation tasks such as image captioning, where a pretrained CNN encodes the image into a feature vector, and that representation is used to initialize the RNN hidden state. The decoder then starts from a special start token and generates one word at a time, feeding each predicted token back into the model until it emits an end token.
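The decoding loop described above can be sketched as follows; the toy vocabulary and random weights are entirely made up for illustration (a real model would be trained, and the hidden state would come from CNN features):

```python
import numpy as np

rng = np.random.default_rng(4)
# Toy vocabulary; the START/END token handling is the point here.
vocab = ["<START>", "<END>", "a", "cat", "sits"]
V, H = len(vocab), 8

W_embed = rng.standard_normal((V, H)) * 0.1  # token embeddings
W_hh = rng.standard_normal((H, H)) * 0.1
W_xh = rng.standard_normal((H, H)) * 0.1
W_hy = rng.standard_normal((V, H)) * 0.1     # hidden -> vocabulary scores

def caption(image_feat, max_len=10):
    h = np.tanh(image_feat)                  # image features initialise the hidden state
    token = vocab.index("<START>")           # decoding starts from the start token
    words = []
    for _ in range(max_len):
        x = W_embed[token]                   # feed the previous token back in
        h = np.tanh(W_hh @ h + W_xh @ x)
        token = int(np.argmax(W_hy @ h))     # greedy: pick the highest-scoring word
        if vocab[token] == "<END>":          # stop when the model emits the end token
            break
        words.append(vocab[token])
    return words

words = caption(rng.standard_normal(H))
```

With random weights the output is gibberish, but the control flow is the same one a trained captioning decoder uses: condition on the image, then feed each prediction back in until `<END>`.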
They were also very successful in character level language modeling, where they learned to generate text one character at a time. A well-known example is Karpathy et al., 2015, which showed that recurrent networks could produce coherent character sequences and that some individual hidden units learned interpretable behaviors, such as tracking quotes, brackets, indentation and other long-range textual structure.
RNN based models were also successfully used in multimodal tasks such as Visual Question Answering, Visual Dialog, and Vision-and-Language Navigation.
To me, there is something deeply fascinating about that idea of an internal state. It is a useful analogy for how we humans move through the world, absorbing sensory information and updating an internal, temporally aware understanding of physical space. That is also why the connection to world models is so compelling to me, since they too must maintain a latent state of the world as it changes through time. But that is for another day.