On attention and transformers

This is part #9 of my notes on CS231n. The course is openly available, including the video lectures and assignments.

These notes are based on this lecture by Justin Johnson.


These notes build on top of RNNs. They introduce the attention mechanism as it first emerged in recurrent encoder-decoder architectures.

Below we have an encoder RNN that processes an input sequence $x_1, \dots, x_i$ via $h_i = f_W(x_i, h_{i-1})$. We take the last hidden state $h_i$ and call it the context $c = h_i$. Then the decoder RNN $s_t = g_U(y_{t-1}, s_{t-1}, c)$ produces the output sequence $y_1, \dots, y_t$.

[Figure: encoder RNN mapping $x_1 \dots x_i$ to hidden states $h_1 \dots h_i$, the context $c = h_i$, and a decoder RNN unrolled from [START] to [STOP]]

As the input sequence grows, the context $c$ becomes a bottleneck: there is a limit to what a single fixed-size vector can "compress".

To address this problem, attention emerges from asking: what if we could look back at the entire input sequence at each step $t$ of the output sequence?

And perhaps more concretely: what if we recomputed the context vector $c$ at every step of the output sequence, selectively attending to the input that conditions that output step?

To look back at the input sequence, we compute a scalar score for each step of the input sequence that tells us how related that encoder hidden state is to the decoder hidden state. Mathematically: $e_{t,i} = f_{att}(s_{t-1}, h_i)$, where $f_{att}$ could be a linear layer. We apply a softmax to convert the scores into a probability distribution $a$, i.e. each weight is between 0 and 1 and they sum to 1. And we take the weighted sum of the hidden states $c_t = \sum_i a_{t,i} h_i$ to compute the context vector $c_t$. This way we have a different $c$ at each step $t$. Just like before, $c_t$ is used to compute the next decoder hidden state $s_t = g_U(y_{t-1}, s_{t-1}, c_t)$. $g_U$ here is any RNN unit (e.g. LSTM, GRU, ...).

[Figure: decoder step 1 — scores $e_{1,i} = f_{att}(s_0, h_i)$ are softmaxed into weights $a_{1,i}$, which multiply the encoder states $h_i$ and sum into the context $c_1$]

With this mechanism, and via gradient descent, the network learns to make the context vector attend to the relevant part of the input sequence. (Bahdanau, Cho & Bengio, 2014)
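As a concrete sketch of one decoder step (NumPy, with toy shapes; for simplicity $f_{att}$ is a plain dot product here, though the lecture allows any small learned network):

```python
import numpy as np

def softmax(x):
    x = x - x.max()              # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum()

# toy shapes: 4 encoder steps, hidden size 8
H = np.random.randn(4, 8)        # encoder hidden states h_1..h_4
s_prev = np.random.randn(8)      # previous decoder state s_{t-1}

e = H @ s_prev                   # scores e_{t,i} = f_att(s_{t-1}, h_i), dot product here
a = softmax(e)                   # attention weights: between 0 and 1, sum to 1
c_t = a @ H                      # context c_t = sum_i a_{t,i} h_i, shape (8,)
```

The decoder would then consume `c_t` together with $y_{t-1}$ and $s_{t-1}$ to produce the next state $s_t$.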

At step 2, note we use $s_1$ instead of $s_0$ to attend over the input sequence and compute $c_2$:

[Figure: decoder step 2 — scores $e_{2,i}$ computed from $s_1$, softmaxed into weights $a_{2,i}$, which weight the $h_i$ and sum into $c_2$]

We repeat this process for each step of the output sequence.

This removes the single-vector bottleneck, and on top of that, the context vector at each decoder timestep can attend to different parts of the input sequence.

This ability to attend selectively is very powerful, so let's abstract it into its own primitive and cut the RNN out.

Let's put some names on what the attention mechanism is actually doing:

  • Let's call data vectors $X \in \mathbb{R}^{N_X\times D_Q}$ what used to be the encoder RNN states $h$.

  • Let's call query vectors $Q \in \mathbb{R}^{N_Q\times D_Q}$ what used to be the decoder RNN states $s$.

  • Let's call output vectors $Y \in \mathbb{R}^{N_Q\times D_X}$ what used to be our context vector $c$. $Y$ is computed

    $Y = AX$

    $Y_i = \sum_j A_{ij} X_j$

    as a result of:

    • Computing data-query similarities, what used to be $e$:

      $E = QX^T / \sqrt{D_Q} \in \mathbb{R}^{N_Q\times N_X}$

      $E_{ij} = Q_i \cdot X_j / \sqrt{D_Q}$

    • Computing the attention-weight probability distribution, what used to be $a$:

      $A = \mathrm{softmax}(E) \in \mathbb{R}^{N_Q\times N_X}$, where each row of $A$ sums to 1
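In matrix form, the definitions above amount to three lines of NumPy (toy shapes; random vectors stand in for real data):

```python
import numpy as np

def softmax_rows(E):
    E = E - E.max(axis=-1, keepdims=True)   # stability: subtract per-row max
    P = np.exp(E)
    return P / P.sum(axis=-1, keepdims=True)

N_X, N_Q, D_Q = 5, 3, 8
X = np.random.randn(N_X, D_Q)   # data vectors (the old encoder states)
Q = np.random.randn(N_Q, D_Q)   # query vectors (the old decoder states)

E = Q @ X.T / np.sqrt(D_Q)      # similarities, shape (N_Q, N_X)
A = softmax_rows(E)             # each row sums to 1
Y = A @ X                       # outputs, shape (N_Q, D_Q)
```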

Because the data vectors $X$ are used both to compute the query similarities $E$ and to provide the content that gets aggregated into the output $Y$, we project $X$ into:

  • Keys $K = XW_K \in \mathbb{R}^{N_X\times D_Q}$, which we will use to compute the similarities:

    $E = QK^T / \sqrt{D_Q} \in \mathbb{R}^{N_Q\times N_X}$

  • Values $V = XW_V \in \mathbb{R}^{N_X\times D_V}$, used in the output vector:

    $Y = AV \in \mathbb{R}^{N_Q\times D_V}$

This lets the model learn one representation for where to attend (keys) and another for what information to return (values). $W_K \in \mathbb{R}^{D_X\times D_Q}$ and $W_V \in \mathbb{R}^{D_X\times D_V}$ are another set of learnable matrices.

With the RNN stripped out and this new naming convention, the Cross-Attention Layer looks like:

$\colorbox{#b2f2bb}{Q}$ is the given query vector input

$\colorbox{#a5d8ff}{X}$ is the given data vector input

$\colorbox{#ffd8a8}{K} = XW_K$

$\colorbox{#d0bfff}{V} = XW_V$

$E = QK^T / \sqrt{D_Q}$

$A = \mathrm{softmax}(E, \text{dim}=1)$

$\colorbox{#ffec99}{Y} = AV$
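Putting the pieces together, a minimal cross-attention layer in NumPy (the weight matrices are random stand-ins for learned parameters):

```python
import numpy as np

def softmax_rows(E):
    E = E - E.max(axis=-1, keepdims=True)
    P = np.exp(E)
    return P / P.sum(axis=-1, keepdims=True)

def cross_attention(Q, X, W_K, W_V):
    D_Q = Q.shape[-1]
    K = X @ W_K                    # keys    (N_X, D_Q)
    V = X @ W_V                    # values  (N_X, D_V)
    E = Q @ K.T / np.sqrt(D_Q)     # similarities (N_Q, N_X)
    A = softmax_rows(E)            # attention weights, rows sum to 1
    return A @ V                   # output  (N_Q, D_V)

N_X, N_Q, D_X, D_Q, D_V = 5, 3, 6, 8, 4
Q = np.random.randn(N_Q, D_Q)
X = np.random.randn(N_X, D_X)
Y = cross_attention(Q, X, np.random.randn(D_X, D_Q), np.random.randn(D_X, D_V))
```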

[Figure: cross-attention layer — queries $Q_i$ and keys $K_j$ produce scores $E_{i,j}$, softmaxed into weights $A_{i,j}$, which multiply the values $V_j$ and sum into outputs $Y_i$]

Another variant is the Self-Attention Layer, where the only input is $X$:

$\colorbox{#a5d8ff}{X} \in \mathbb{R}^{N\times D_{in}}$ is the given data vector input

$\colorbox{#b2f2bb}{Q} = XW_Q \in \mathbb{R}^{N\times D_{out}}$, $W_Q \in \mathbb{R}^{D_{in}\times D_{out}}$ being a new learnable matrix

$\colorbox{#ffd8a8}{K} = XW_K \in \mathbb{R}^{N\times D_{out}}$, $W_K \in \mathbb{R}^{D_{in}\times D_{out}}$

$\colorbox{#d0bfff}{V} = XW_V \in \mathbb{R}^{N\times D_{out}}$, $W_V \in \mathbb{R}^{D_{in}\times D_{out}}$

$E = QK^T / \sqrt{D_{out}} \in \mathbb{R}^{N\times N}$

$A = \mathrm{softmax}(E, \text{dim}=1) \in \mathbb{R}^{N\times N}$

$\colorbox{#ffec99}{Y} = AV \in \mathbb{R}^{N\times D_{out}}$

[Figure: self-attention layer — same computation as cross-attention, but $Q$, $K$, and $V$ are all projected from $X$]

In practice we usually use multi-head attention, where the model computes several attention heads ($H$) in parallel, each with its own learned projections:

$\colorbox{#a5d8ff}{X} \in \mathbb{R}^{N\times D}$

$\colorbox{#b2f2bb}{Q} = XW_Q \in \mathbb{R}^{H \times N\times D_{H}}$, $W_Q \in \mathbb{R}^{D\times HD_{H}}$

$\colorbox{#ffd8a8}{K} = XW_K \in \mathbb{R}^{H \times N\times D_{H}}$, $W_K \in \mathbb{R}^{D\times HD_{H}}$

$\colorbox{#d0bfff}{V} = XW_V \in \mathbb{R}^{H \times N\times D_{H}}$, $W_V \in \mathbb{R}^{D\times HD_{H}}$

$E = \colorbox{#b2f2bb}{Q}\,\colorbox{#ffd8a8}{K}^T / \sqrt{D_H} \in \mathbb{R}^{H\times N \times N}$

$A = \mathrm{softmax}(E, \text{dim}=1) \in \mathbb{R}^{H\times N \times N}$

$\colorbox{#ffec99}{Y} = A\,\colorbox{#d0bfff}{V} \in \mathbb{R}^{N\times HD_{H}}$

$\colorbox{#ff8787}{O} = \colorbox{#ffec99}{Y}W_O \in \mathbb{R}^{N\times D}$, $W_O \in \mathbb{R}^{HD_H\times D}$ being a new learnable matrix to fuse the output of each head.

$D_{H}$ is the head dimension; usually $D_{H} = D / H$.

[Figure: multi-head attention with e.g. $H = 3$ — three self-attention heads run on $X$ in parallel, and their outputs $Y_h$ are concatenated and fused into $O$]

Perhaps surprisingly, it can all be computed with 4 big matmuls:

  1. QKV Projection

    $[\,\colorbox{#b2f2bb}{Q} \;|\; \colorbox{#ffd8a8}{K} \;|\; \colorbox{#d0bfff}{V}\,] = \colorbox{#a5d8ff}{X}\,[\,W_Q \;|\; W_K \;|\; W_V\,]$

    $[\,Q \;|\; K \;|\; V\,] \in \mathbb{R}^{N \times 3HD_H} = (N \times D)\,(D \times 3HD_H)$

  2. QK Similarity

    $E = \colorbox{#b2f2bb}{Q}\,\colorbox{#ffd8a8}{K}^T$

    $E \in \mathbb{R}^{H \times N \times N} = (H\times N\times D_H)\,(H\times D_H\times N)$

  3. V-Weighting

    $\colorbox{#ffec99}{Y} = A\,\colorbox{#d0bfff}{V}$

    $Y \in \mathbb{R}^{H\times N\times D_H} = (H\times N\times N)\,(H\times N\times D_H)$

    Reshape to $\colorbox{#ffec99}{Y} \in \mathbb{R}^{N \times HD_H}$

  4. Output Projection

    $\colorbox{#ff8787}{O} = \colorbox{#ffec99}{Y}\,W_O$

    $O \in \mathbb{R}^{N \times D} = (N \times HD_H)\,(HD_H \times D)$
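Those four matmuls can be sketched in NumPy (toy sizes; the fused QKV weight and output weight are random stand-ins for learned parameters):

```python
import numpy as np

def softmax_last(E):
    E = E - E.max(axis=-1, keepdims=True)
    P = np.exp(E)
    return P / P.sum(axis=-1, keepdims=True)

N, D, H = 4, 12, 3
D_H = D // H
X = np.random.randn(N, D)
W_qkv = np.random.randn(D, 3 * H * D_H)   # [W_Q | W_K | W_V] fused
W_O = np.random.randn(H * D_H, D)

# 1. QKV projection: one big matmul, then split and reshape each to (H, N, D_H)
QKV = X @ W_qkv
Q, K, V = (M.reshape(N, H, D_H).transpose(1, 0, 2) for M in np.split(QKV, 3, axis=-1))

# 2. QK similarity, batched over heads
E = Q @ K.transpose(0, 2, 1) / np.sqrt(D_H)   # (H, N, N)
A = softmax_last(E)

# 3. V-weighting, then reshape back to (N, H*D_H)
Y = (A @ V).transpose(1, 0, 2).reshape(N, H * D_H)

# 4. Output projection
O = Y @ W_O                                   # (N, D)
```

NumPy's `@` broadcasts over the leading head dimension, so steps 2 and 3 are a single batched matmul each.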


The Transformer

A Transformer block first applies multi-head self-attention then adds a residual connection and layer-normalizes; it then applies a position-wise feed-forward network (FFN/MLP) independently to each vector, followed by another residual addition and a final layer normalization:

[Figure: a Transformer block — inputs $x_1 \dots x_4$ pass through self-attention with a residual add and Layer Normalization, then a per-vector MLP with another residual add and Layer Normalization, producing $y_1 \dots y_4$]

A Transformer (Vaswani et al., 2017) is built by stacking multiple Transformer blocks sequentially:

[Figure: three Transformer blocks stacked sequentially]

That is the post-LN (LayerNorm after the sublayer) description. Many modern GPT-style models are pre-LN (LayerNorm before the sublayer).
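The two orderings can be sketched as follows (NumPy; the attention and MLP sublayers are random stand-ins, just to show the wiring of residuals and LayerNorm):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize each vector to zero mean, unit variance (learnable scale/shift omitted)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def post_ln_block(x, attn, mlp):
    # original Transformer ordering: sublayer -> residual add -> LayerNorm
    x = layer_norm(x + attn(x))
    x = layer_norm(x + mlp(x))
    return x

def pre_ln_block(x, attn, mlp):
    # GPT-style ordering: LayerNorm -> sublayer -> residual add
    x = x + attn(layer_norm(x))
    x = x + mlp(layer_norm(x))
    return x

# stand-ins for the real sublayers
attn = lambda x: x @ (np.random.randn(x.shape[-1], x.shape[-1]) * 0.1)
mlp = lambda x: np.maximum(0, x @ (np.random.randn(x.shape[-1], x.shape[-1]) * 0.1))
out = post_ln_block(np.random.randn(4, 8), attn, mlp)
```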

Their size has only grown:

Original: 12 blocks; D=1024, H=16, N=512, 213M params

GPT-2: 48 blocks; D=1600, H=25, N=1024, 1.5B params

GPT-3: 96 blocks; D=12288, H=96, N=2048, 175B params


Transformers are being successfully applied beyond text to images with ViT (Dosovitskiy et al., 2020), to video with ViViT, to multimodal image-text learning with CLIP, and also to audio, protein sequences, time series...

Since its introduction, the Transformer has been extended with many architectural and domain-specific refinements, but the core ideas presented here remain the essential conceptual starting point.

One variation worth mentioning is Mixture of Experts (MoE). Instead of a single dense feed-forward network in each block, we learn E different MLPs, called experts. A routing function sends each token to a subset of them. This allows the model to scale total parameter count significantly while keeping per-token compute much lower than if all experts were active.
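A toy sketch of top-1 routing (NumPy; the router and experts are random stand-ins for learned parameters, and real MoE layers add load-balancing details omitted here):

```python
import numpy as np

N, D, num_experts = 6, 8, 4
X = np.random.randn(N, D)
# each expert is its own tiny MLP; the default-arg trick gives each its own weights
experts = [lambda x, W=np.random.randn(D, D): np.maximum(0, x @ W)
           for _ in range(num_experts)]
W_router = np.random.randn(D, num_experts)

scores = X @ W_router              # router logits, (N, num_experts)
choice = scores.argmax(axis=-1)    # top-1 routing: one expert per token

Y = np.empty_like(X)
for e in range(num_experts):
    mask = choice == e
    if mask.any():
        Y[mask] = experts[e](X[mask])   # only the chosen expert runs per token
```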

Other important refinements include causal masking, which restricts token $t$ from attending to future tokens and is central to the autoregressive decoder in the original Transformer, and Rotary Position Embeddings (RoPE), which encode positional information directly into attention.
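Causal masking amounts to setting the scores above the diagonal to $-\infty$ before the softmax, so the attention weights on future positions become exactly zero. A NumPy sketch:

```python
import numpy as np

N = 4
E = np.random.randn(N, N)                          # raw similarity scores
mask = np.triu(np.ones((N, N), dtype=bool), k=1)   # True above the diagonal (future)
E[mask] = -np.inf                                  # block attention to the future

A = np.exp(E - E.max(axis=-1, keepdims=True))      # softmax over each row
A = A / A.sum(axis=-1, keepdims=True)              # rows sum to 1; A[i, j] == 0 for j > i
```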
