These notes build on top of RNNs. They introduce the attention mechanism as it first emerged in recurrent encoder-decoder architectures.
Below we have an encoder RNN that processes an input sequence x1,…,xN step by step, hi=fW(xi,hi−1); we take the last hidden state hN and call it the context, c=hN.
Then a decoder RNN st=gU(yt−1,st−1,c) produces the output sequence y1,…,yT.
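As a concrete sketch of this setup: here fW and gU are assumed to be plain tanh RNN cells, and all dimensions and weights are made up for illustration.

```python
import numpy as np

def rnn_step(W, x, h):
    # generic tanh RNN cell: h_t = tanh(W [x_t ; h_{t-1}])
    return np.tanh(W @ np.concatenate([x, h]))

rng = np.random.default_rng(0)
D = 4                                      # hidden size (hypothetical)
W_enc = 0.1 * rng.normal(size=(D, 2 * D))
W_dec = 0.1 * rng.normal(size=(D, 3 * D))  # consumes y_{t-1}, s_{t-1} and c

# encoder: fold the whole input sequence into hidden states h_1..h_N
xs = [rng.normal(size=D) for _ in range(5)]
h = np.zeros(D)
for x in xs:
    h = rnn_step(W_enc, x, h)
c = h                                      # the single fixed context vector

# decoder: every step conditions on the SAME context c -- the bottleneck
def decoder_step(y_prev, s_prev, c):
    return np.tanh(W_dec @ np.concatenate([y_prev, s_prev, c]))

s, y = np.zeros(D), np.zeros(D)
s = decoder_step(y, s, c)
```

Note that no matter how long `xs` gets, everything the decoder sees of the input is the single D-dimensional vector `c`.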
As the input sequence grows, the context c becomes a bottleneck: there is a limit to how much information a single fixed-size vector can compress.
To address this problem, attention emerges from a simple question: what if we could look back at the entire input sequence at each step t of the output sequence? More concretely, what if we recomputed the context vector c at every output step, selectively attending to the input that conditions that output?
To look back at the input sequence, we compute a scalar value for each step of the input sequence that tells us how related each encoder hidden state is to the current decoder state. Mathematically: et,i=fatt(st−1,hi), where fatt could be a small linear layer.
We apply softmax to convert these scores into a probability distribution at, i.e. each weight at,i is between 0 and 1 and they sum to 1.
And take the weighted sum of the hidden states ct=∑iat,ihi to compute the context vector ct.
This way we have a different c at each step t.
Just like before, ct is used to compute the next decoder hidden state st=gU(yt−1,st−1,ct). gU here is any RNN unit (e.g. LSTM, GRU ...)
With this mechanism, and via gradient descent, the network learns to make the context vector attend to the relevant part of the input sequence. (Bahdanau, Cho & Bengio, 2014)
At step 2, note that we use s1 instead of s0 to attend over the input sequence and compute c2:
We repeat this process for each step of the output sequence.
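One such attention step can be sketched in NumPy. Here fatt is assumed to be a single linear layer applied to the concatenation [st−1 ; hi], and the dimensions are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D = 5, 4
hs = rng.normal(size=(N, D))        # encoder hidden states h_1..h_N
s_prev = rng.normal(size=D)         # previous decoder state s_{t-1}
w_att = rng.normal(size=2 * D)      # f_att as a single linear layer (assumption)

# alignment scores e_{t,i} = f_att(s_{t-1}, h_i), one scalar per input step
e = np.array([w_att @ np.concatenate([s_prev, h]) for h in hs])

# softmax turns the scores into attention weights a_t that sum to 1
a = np.exp(e - e.max())
a /= a.sum()

# context c_t = sum_i a_{t,i} h_i, recomputed at every decoder step
c_t = a @ hs
```

Running this once per decoder step yields a fresh context c_t each time, instead of the single fixed c.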
This removes the single-vector bottleneck and, on top of that, lets the context vector attend to different parts of the input sequence at each decoder timestep.
This ability to attend selectively is very powerful; let's abstract it into its own primitive and cut the RNN out.
Let's put some names on what the attention mechanism is actually doing:
Let's call data vectors X∈RNX×DX what used to be the encoder RNN states h (for now assume DX=DQ, so the dot products below are defined).
Let's call query vectors Q∈RNQ×DQ what used to be the decoder RNN states s.
Let's call output vectors Y∈RNQ×DX what used to be our context vector c. Y is computed
Y=AX
Yi=∑jAijXj
as a result of:
Computing data-query similarities, what used to be e:
E=QXT/√DQ∈RNQ×NX
Eij=Qi·Xj/√DQ
(The √DQ scaling keeps the dot products from growing with the dimension, which would otherwise saturate the softmax.)
Computing the attention weights, a probability distribution (what used to be a):
A=softmax(E)∈RNQ×NX where softmax is applied row-wise, so each row of A sums to 1
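The equations above fit in a few lines of NumPy (dimensions are made up for illustration):

```python
import numpy as np

def softmax_rows(E):
    E = E - E.max(axis=1, keepdims=True)   # subtract row max for numerical stability
    A = np.exp(E)
    return A / A.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
NQ, NX, DQ = 3, 5, 4
Q = rng.normal(size=(NQ, DQ))   # query vectors
X = rng.normal(size=(NX, DQ))   # data vectors (same dim as Q in this simple form)

E = Q @ X.T / np.sqrt(DQ)       # scaled similarities, (NQ, NX)
A = softmax_rows(E)             # each row is a probability distribution
Y = A @ X                       # one weighted sum of the X_j per query, (NQ, DQ)
```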
Because the data vectors X are being used both to compute the query similarities E and to provide the content that gets aggregated into the output Y, we project X into two separate representations:
Keys K=XWK∈RNX×DQ, which we will use to compute the similarities:
E=QKT/√DQ∈RNQ×NX
Values V=XWV∈RNX×DV, used in the output vector:
Y=AV∈RNQ×DV
This lets the model learn one representation for where to attend (keys) and another for what information to return (values).
WK∈RDX×DQ and WV∈RDX×DV are another set of learnable matrices.
With the RNN stripped out and these new naming conventions, the Cross-Attention Layer looks like:
Q is the given query vector input
X is the given data vector input
K=XWK
V=XWV
E=QKT/√DQ
A=softmax(E,dim=1)
Y=AV
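The five lines above translate almost one-to-one into a NumPy sketch (the weight matrices are random stand-ins for learned parameters, and the dimensions are hypothetical):

```python
import numpy as np

def softmax_rows(E):
    E = E - E.max(axis=1, keepdims=True)   # numerical stability
    A = np.exp(E)
    return A / A.sum(axis=1, keepdims=True)

def cross_attention(Q, X, W_K, W_V):
    K = X @ W_K                            # keys: where to attend
    V = X @ W_V                            # values: what to return
    E = Q @ K.T / np.sqrt(Q.shape[-1])     # similarities, (NQ, NX)
    A = softmax_rows(E)                    # attention weights
    return A @ V                           # outputs, (NQ, DV)

rng = np.random.default_rng(0)
NQ, NX, DX, DQ, DV = 3, 5, 6, 4, 7
Q = rng.normal(size=(NQ, DQ))
X = rng.normal(size=(NX, DX))
W_K = rng.normal(size=(DX, DQ))
W_V = rng.normal(size=(DX, DV))
Y = cross_attention(Q, X, W_K, W_V)
```

Note how X no longer needs the same dimension as Q: the key projection handles the match-up.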
Another variant is a Self-Attention Layer, where the only input is X:
X∈RN×Din is the given data vector input
Q=XWQ∈RN×Dout, WQ∈RDin×Dout being a new learnable matrix.
K=XWK∈RN×Dout, WK∈RDin×Dout
V=XWV∈RN×Dout, WV∈RDin×Dout
E=QKT/√Dout∈RN×N
A=softmax(E,dim=1)∈RN×N
Y=AV∈RN×Dout
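The Self-Attention Layer can be sketched the same way in NumPy; WQ, WK, WV are random stand-ins for learned weights and the dimensions are made up:

```python
import numpy as np

def softmax_rows(E):
    E = E - E.max(axis=1, keepdims=True)   # numerical stability
    A = np.exp(E)
    return A / A.sum(axis=1, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V    # all three derived from X alone
    E = Q @ K.T / np.sqrt(Q.shape[-1])     # similarities, (N, N)
    A = softmax_rows(E)                    # each position attends to every position
    return A @ V                           # outputs, (N, Dout)

rng = np.random.default_rng(0)
N, Din, Dout = 5, 6, 4
X = rng.normal(size=(N, Din))
W_Q, W_K, W_V = (rng.normal(size=(Din, Dout)) for _ in range(3))
Y = self_attention(X, W_Q, W_K, W_V)
```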
In practice we usually use multi-head attention, where the model computes H attention heads in parallel, each with its own learned projections:
X∈RN×D
Q=XWQ∈RN×HDH, reshaped to RH×N×DH; WQ∈RD×HDH
K=XWK∈RN×HDH, reshaped to RH×N×DH; WK∈RD×HDH
V=XWV∈RN×HDH, reshaped to RH×N×DH; WV∈RD×HDH
E=QKT/√DH∈RH×N×N (a batched matmul over the H heads)
A=softmax(E,dim=2)∈RH×N×N (softmax over the last, key dimension)
Y=AV∈RH×N×DH, with the heads then concatenated back to RN×HDH
O=YWO∈RN×D, WO∈RHDH×D being a new learnable matrix to fuse the output of each head.
DH being the head dimension; usually DH=D/H (e.g. with H=3 heads, DH=D/3)
Perhaps surprisingly, all of this can be computed with just 4 big matmuls:
QKV Projection
[Q∣K∣V]=X[WQ∣WK∣WV]
[Q∣K∣V]∈RN×3HDH=(N×D)(D×3HDH)
QK Similarity
E=QKT/√DH
E∈RH×N×N=(H×N×DH)(H×DH×N)
V-Weighting
Y=AV
Y∈RH×N×DH=(H×N×N)(H×N×DH)
Reshape to Y∈RN×HDH
Output Projection
O=YWO
O∈RN×D=(N×HDH)(HDH×D)
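The four matmuls above can be written as one NumPy function; the weights here are random stand-ins for learned parameters, and the softmax/reshapes in between are cheap elementwise or view operations:

```python
import numpy as np

def multi_head_attention(X, W_qkv, W_O, H):
    N, D = X.shape
    DH = W_qkv.shape[1] // (3 * H)                    # head dimension

    # matmul 1: fused QKV projection, (N, 3*H*DH), split into three (N, H*DH)
    Q, K, V = np.split(X @ W_qkv, 3, axis=1)

    # reshape each to (H, N, DH): one slice per head
    split_heads = lambda M: M.reshape(N, H, DH).transpose(1, 0, 2)
    Q, K, V = split_heads(Q), split_heads(K), split_heads(V)

    # matmul 2: per-head similarities, (H, N, N)
    E = Q @ K.transpose(0, 2, 1) / np.sqrt(DH)

    # softmax over the key axis (the last one)
    E = E - E.max(axis=-1, keepdims=True)
    A = np.exp(E)
    A /= A.sum(axis=-1, keepdims=True)

    # matmul 3: V-weighting, (H, N, DH)
    Y = A @ V

    # concatenate heads back to (N, H*DH), then matmul 4: output projection
    Y = Y.transpose(1, 0, 2).reshape(N, H * DH)
    return Y @ W_O

rng = np.random.default_rng(0)
N, D, H = 6, 12, 3
W_qkv = rng.normal(size=(D, 3 * D))   # here DH = D/H = 4
W_O = rng.normal(size=(D, D))
O = multi_head_attention(rng.normal(size=(N, D)), W_qkv, W_O, H)
```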
The Transformer
A Transformer block first applies multi-head self-attention then adds a residual connection and layer-normalizes; it then applies a position-wise feed-forward network (FFN/MLP) independently to each vector, followed by another residual addition and a final layer normalization:
A Transformer(Vaswani et al., 2017) is built by stacking multiple Transformer blocks sequentially:
That is the post-LN (LayerNorm after the sublayer) arrangement; many modern GPT-style models use pre-LN (LayerNorm before the sublayer) instead.
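A post-LN block can be sketched as follows. To keep it short, the attention sublayer is an identity stand-in (any of the self-attention sketches above would slot in), the FFN weights are random stand-ins, and the learnable LayerNorm scale/shift are omitted:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize each vector to zero mean, unit variance along the feature axis
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def transformer_block(X, attn, mlp):
    # post-LN: sublayer -> residual add -> LayerNorm, twice
    X = layer_norm(X + attn(X))
    X = layer_norm(X + mlp(X))
    return X

rng = np.random.default_rng(0)
N, D = 6, 8
X = rng.normal(size=(N, D))
W1 = 0.1 * rng.normal(size=(D, 4 * D))
W2 = 0.1 * rng.normal(size=(4 * D, D))
mlp = lambda X: np.maximum(X @ W1, 0) @ W2   # position-wise FFN (ReLU)
attn = lambda X: X                           # stand-in for multi-head self-attention
out = transformer_block(X, attn, mlp)
```

Stacking this function on its own output several times gives the full Transformer; a pre-LN variant would instead compute X + attn(layer_norm(X)).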
Transformers are being successfully applied beyond text to images with ViT (Dosovitskiy et al., 2020), to video with ViViT, to multimodal image-text learning with CLIP, and also to audio, protein sequences, time series...
Since its introduction, the Transformer has been extended with many architectural and domain-specific refinements, but the core ideas presented here remain the essential conceptual starting point.
One variation worth mentioning is Mixture of Experts (MoE). Instead of a single dense feed-forward network in each block, we learn E different MLPs, called experts. A routing function sends each token to a subset of them.
This allows the model to scale its total parameter count significantly while keeping per-token compute much lower than activating all experts would require.
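A minimal routing sketch, assuming top-k gating (one common routing choice, not the only one); the expert MLPs, gate weights, and dimensions are all made up:

```python
import numpy as np

def moe_ffn(X, W_gate, experts, k=2):
    # router: a softmax over per-expert logits for each token
    logits = X @ W_gate                               # (N, num_experts)
    gates = np.exp(logits - logits.max(axis=1, keepdims=True))
    gates /= gates.sum(axis=1, keepdims=True)

    out = np.zeros_like(X)
    for n in range(X.shape[0]):
        top = np.argsort(gates[n])[-k:]               # indices of the top-k experts
        w = gates[n, top] / gates[n, top].sum()       # renormalize their weights
        for wi, e in zip(w, top):
            out[n] += wi * experts[e](X[n])           # only k experts run per token
    return out

rng = np.random.default_rng(0)
N, D, num_experts = 4, 8, 4
X = rng.normal(size=(N, D))
W_gate = rng.normal(size=(D, num_experts))
# each expert is a small two-layer MLP with its own weights
Ws = [(0.1 * rng.normal(size=(D, D)), 0.1 * rng.normal(size=(D, D)))
      for _ in range(num_experts)]
experts = [lambda x, W1=W1, W2=W2: np.maximum(x @ W1, 0) @ W2 for W1, W2 in Ws]
out = moe_ffn(X, W_gate, experts, k=2)
```

With k=2 of 4 experts active, each token touches only half the FFN parameters, which is the source of the compute savings.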
Other important refinements include causal masking, which restricts token t from attending to future tokens and is central to autoregressive decoders in the original Transformer, and Rotary Position Embeddings (RoPE), which encode positional information directly into attention.
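Causal masking amounts to setting the similarity scores for future positions to −∞ before the softmax, so their attention weights become exactly 0. A sketch on top of the earlier self-attention code (random weights, made-up dimensions):

```python
import numpy as np

def causal_self_attention(X, W_Q, W_K, W_V):
    N = X.shape[0]
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    E = Q @ K.T / np.sqrt(Q.shape[-1])                 # (N, N) similarities
    mask = np.triu(np.ones((N, N), dtype=bool), k=1)   # True where j > i (the future)
    E = np.where(mask, -np.inf, E)                     # future scores -> -inf
    E = E - E.max(axis=1, keepdims=True)
    A = np.exp(E)                                      # exp(-inf) = 0: zero weight on future
    A /= A.sum(axis=1, keepdims=True)
    return A @ V, A

rng = np.random.default_rng(0)
N, D = 5, 4
X = rng.normal(size=(N, D))
W_Q, W_K, W_V = (rng.normal(size=(D, D)) for _ in range(3))
Y, A = causal_self_attention(X, W_Q, W_K, W_V)
```

Row 0 of A can only attend to itself, so its single weight is 1; every row t spreads its weight over positions 0..t only.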