--- Attention mechanism
The attention mechanism allows a model to dynamically focus on different parts of the input sequence when generating each part of the output sequence, rather than treating all input tokens as equally important.
The first widely used attention mechanism was introduced in the 2014 paper “Neural Machine Translation by Jointly Learning to Align and Translate” by Bahdanau, Cho, and Bengio. It's an RNN-based attention model. (For key attention concepts and a basic example of how to compute attention scores, see the RNN-with-Attention notes)
--- Self-attention
Transformers use self-attention, whereas earlier models mainly used attention as cross-attention between the encoder and decoder.
Self-attention is still a form of attention. We'll show a detailed example in the "Self-Attention" section. It follows the standard mechanism: use Q and K to compute attention scores and weights, and then use V and the weights to compute context vectors. Self-attention allows the model to determine the importance of different parts of the input sequence when making predictions.
Takeaways:
- When was positional encoding first proposed?
- What is the usage of positional encoding, and what function is commonly used?
- What is the difference between cross-attention and self-attention?
- When was multi-head attention first proposed?
- In Multi-Head Attention, are the input Q, K, V for each head the same? What is the output of each head?
- Why must the output dimension after concatenation in Multi-Head Attention be equal to the input embedding dimension?
- Can parallelism be achieved in a transformer-based next-token prediction task? Is there any difference between training and inference?
- In the inference process of a next-token prediction task, why do we have a K, V cache but not a Q cache?
- What is the difference between nn.TransformerEncoderLayer and nn.TransformerEncoder? What are their inputs & outputs? If I want to implement a certain functionality, can I achieve it with either one?
- What is the difference between nn.TransformerDecoderLayer and nn.TransformerDecoder? What are their inputs & outputs? If I want to implement a certain functionality, can I achieve it with either one?
- What is the input and output of nn.Transformer?
--- Positional Encoding¶
Positional Encoding was first introduced in the Transformer model proposed in the 2017 paper “Attention is All You Need.” Positional encoding is used to inject information about the relative or absolute position of tokens in the sequence. They proposed two types
- fixed sinusoidal positional encoding
- learnable positional embeddings.
Positional Encoding is introduced because the Transformer architecture removes recurrence and convolution, so it lacks an inherent sense of word order. Transformers do not inherently understand the order of tokens in the input (they treat the input sequence as a set of vectors without any sequential structure), so we add positional encoding to supply that information.
Implementation Highlights¶
Sinusoidal Functions: The positional encoding uses sinusoidal functions of different frequencies to generate a unique positional vector for each position. This method ensures that each position has a distinct encoding and that similar positions have similar encodings, capturing the sequential nature of the data. Sinusoidal functions include:
- sin(x)
- cos(x)
- Or any of their variants (sin(ax), cos(bx), sin(kx+c), etc.)
The sine and cosine positional encoding formulas from "Attention Is All You Need": $$ PE_{(pos,2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right) $$ $$ PE_{(pos,2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right) $$ Meaning:
- pos: position (0, 1, 2, 3, ...)
- i: dimension index
- $d_{\text{model}}$: embedding dimension
Equations to Code:
position: A tensor representing the positions in the sequence.
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term: A tensor that scales the positions to generate the different frequencies of the sinusoidal functions. The denominator in the formula is: $$ 10000^{2i/d_{\text{model}}} = \exp\left(\log(10000) \cdot \frac{2i}{d_{\text{model}}}\right) $$ The code computes its reciprocal, $\exp\left(-\log(10000) \cdot \frac{2i}{d_{\text{model}}}\right)$, which is then multiplied by the positions. Code:
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
Sinusoidal Values:
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe[:, 0::2]: Assigns the sine values to the even indices of the encoding. pe[:, 1::2]: Assigns the cosine values to the odd indices of the encoding.
Usage: The positional encoding is added to the input embeddings, which combines the positional information with the token representations.
def forward(self, x):
    return x + self.pe[:x.size(0), :]
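Putting the pieces above together, here is a minimal sketch of a sinusoidal positional-encoding module. The class name and the (seq_len, batch, d_model) input layout are assumptions for illustration (they match the usage in the implementation section later):
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute the table once; assumes an even d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)   # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        pe = pe.unsqueeze(1)                           # (max_len, 1, d_model), broadcasts over the batch
        self.register_buffer("pe", pe)                 # stored with the module, but not trainable

    def forward(self, x):
        # x: (seq_len, batch, d_model); add the encodings of the first seq_len positions
        return x + self.pe[:x.size(0), :]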
--- Self-Attention¶
The paper "A Structured Self-Attentive Sentence Embedding" (Lin et al., 2017) had already used self-attention for sentence encoding, but at that time it was only an auxiliary module on top of an RNN. The Transformer was the first model to systematically and extensively use self-attention as its core mechanism.
What is the difference between cross-attention and self-attention? -- The way Q, K, and V are chosen.
Self-attention: Q, K, and V all come from the same sequence (such as within the encoder or decoder). Cross-attention: Q comes from the target sequence (decoder), while K and V come from the source sequence (encoder).
Self-Attention Mechanism¶
Self-attention means each element attends to itself and others in the same sequence. For a sequence,
- Query (Q): Represents the current word (or token) for which we are calculating the attention. (Each word will be used as Q once)
- Key (K): Represents all the words (or tokens) in the sequence.
- Value (V): Also represents all the words (or tokens) in the sequence.
Toy Math Example: Processing the Sentence "I love AI"
Input Embedding
Suppose we have a simple input sentence: "I love AI". For simplicity, let's use the following embeddings for these words:
- "I" -> $[0.1, 0.2]$
- "love" -> $[0.3, 0.4]$
- "AI" -> $[0.5, 0.6]$
So, the input sequence $X$ is:
X = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
Positional Encoding
To provide positional information, we add simple positional encodings:
- Position 0 -> $[0.0, 0.1]$
- Position 1 -> $[0.1, 0.2]$
- Position 2 -> $[0.2, 0.3]$
Adding these to the embeddings, we get:
Positional Encoded X = [[0.1 + 0.0, 0.2 + 0.1], [0.3 + 0.1, 0.4 + 0.2], [0.5 + 0.2, 0.6 + 0.3]] = [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]]
Linear Transformation to Obtain Q, K, V
We perform linear transformations to get Query (Q), Key (K), and Value (V) matrices. For simplicity, assume the following weight matrices $W_Q$, $W_K$, and $W_V$:
W_Q = W_K = W_V = [[1, 0], [0, 1]]
Therefore:
Q = K = V = [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]]
Compute Attention Scores
Compute the attention scores using the dot product of Q and K^T:
Attention scores = Q * K^T = [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]] * [[0.1, 0.4, 0.7], [0.3, 0.6, 0.9]] = [[(0.1*0.1 + 0.3*0.3), (0.1*0.4 + 0.3*0.6), (0.1*0.7 + 0.3*0.9)], [(0.4*0.1 + 0.6*0.3), (0.4*0.4 + 0.6*0.6), (0.4*0.7 + 0.6*0.9)], [(0.7*0.1 + 0.9*0.3), (0.7*0.4 + 0.9*0.6), (0.7*0.7 + 0.9*0.9)]] = [[0.10, 0.22, 0.34], [0.22, 0.52, 0.82], [0.34, 0.82, 1.30]]
Apply Softmax to Get Attention Weights
Normalize the scores using the softmax function:
Attention weights = softmax(Attention scores). Rounding the softmax results to two decimals: ≈ [[0.29, 0.33, 0.37], [0.24, 0.32, 0.44], [0.19, 0.31, 0.50]]
Compute the Weighted Sum of Values
Use the attention weights to compute the weighted sum of the value vectors:
Self-Attention output = Attention weights * V = [[0.29, 0.33, 0.37], [0.24, 0.32, 0.44], [0.19, 0.31, 0.50]] * [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]] = [[(0.29*0.1 + 0.33*0.4 + 0.37*0.7), (0.29*0.3 + 0.33*0.6 + 0.37*0.9)], [(0.24*0.1 + 0.32*0.4 + 0.44*0.7), (0.24*0.3 + 0.32*0.6 + 0.44*0.9)], [(0.19*0.1 + 0.31*0.4 + 0.50*0.7), (0.19*0.3 + 0.31*0.6 + 0.50*0.9)]] ≈ [[0.42, 0.62], [0.46, 0.66], [0.49, 0.69]]
In the above example, the calculations for attention scores, attention weights, and the weighted sums for each word ("I", "love", "AI") are independent and can be executed simultaneously.
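As a quick sanity check of the arithmetic above, the same toy example can be reproduced in a few lines of PyTorch. This is a minimal sketch: W_Q = W_K = W_V is the identity, so Q = K = V = X, and the softmax is computed exactly rather than rounded:
import torch
import torch.nn.functional as F

# Positionally encoded inputs from the toy example; Q = K = V because W_Q = W_K = W_V = I
X = torch.tensor([[0.1, 0.3],
                  [0.4, 0.6],
                  [0.7, 0.9]])
Q = K = V = X

scores = Q @ K.T                      # (3, 3) attention scores
weights = F.softmax(scores, dim=-1)   # row-wise attention weights
output = weights @ V                  # (3, 2) context vectors, one per word

print(scores)   # ~[[0.10, 0.22, 0.34], [0.22, 0.52, 0.82], [0.34, 0.82, 1.30]]
print(weights)  # ~[[0.29, 0.33, 0.37], [0.24, 0.32, 0.44], [0.19, 0.31, 0.50]]
print(output)   # ~[[0.42, 0.62], [0.46, 0.66], [0.49, 0.69]]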
--- Multi-head Attention¶
In the 2017 paper "Attention is All You Need", the authors first proposed Multi-Head Attention, as an extension of Self-Attention and Cross-Attention that allows the model to learn different representations from multiple subspaces in parallel. It is one of the key innovations that led to the success of the Transformer. Previous attention mechanisms (such as Bahdanau Attention, Luong Attention, or Self-Attention) were all single-head, without the concept of "multi-head".
Multi-Head Attention is like having a team of note-takers, where each person focuses on a different aspect of the same lecture: one listens for key facts, another pays attention to tone and emotion, another tracks cause-and-effect. Afterward, they combine their notes to form a more complete understanding. This allows the model to focus on different parts of the input in multiple ways, enhancing its ability to capture various aspects of the input data.
Like self-attention, Multi-Head Attention is not exclusive to the Transformer architecture; it can also be used in RNN-based models.
Compute Multi-head Attention¶
Multi-head attention involves computing multiple sets of alignment scores in parallel, each with different learned linear transformations of the same Q, K, V. Each head computes attention score (usually scaled dot-product, other alignment functions are also used occasionally), and then computes a context vector as the output of each head.
$$\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
Where $W_i^Q$, $W_i^K$, and $W_i^V$ are learned projection matrices for the $i$-th head. Finally, we concatenate the outputs of all heads and (optionally) apply another learned linear transformation:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O$$
Where $W_O$ is a learned output projection matrix. In Multi-Head Attention, the final output $\text{MultiHead}(Q, K, V)$ must have the same dimension as the input embedding. This is because the MHA output is added to the input via a residual connection, and the dimensions must match: $$ \text{MHA}(x) + x $$ The final linear layer $W_O$ mainly serves this purpose. Of course, if $\text{Concat}(\text{head}_1, \ldots, \text{head}_h)$ already matches the input embedding size, $W_O$ may not be necessary.
Toy Math Example
- Embedding Dimension (embed_dim): 4
- Number of Heads (num_heads): 2
- Output Dimension per Head:
embed_dim / num_heads = 4 / 2 = 2
Assume we have 3 keys and values, and 1 query.
Head 1: $W_Q^{(1)} = \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \\ 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix} \qquad W_K^{(1)} = \begin{bmatrix} 0.9 & 1.0 \\ 1.1 & 1.2 \\ 1.3 & 1.4 \\ 1.5 & 1.6 \end{bmatrix} \qquad W_V^{(1)} = \begin{bmatrix} 1.7 & 1.8 \\ 1.9 & 2.0 \\ 2.1 & 2.2 \\ 2.3 & 2.4 \end{bmatrix}$
Head 2: $W_Q^{(2)} = \begin{bmatrix} 2.5 & 2.6 \\ 2.7 & 2.8 \\ 2.9 & 3.0 \\ 3.1 & 3.2 \end{bmatrix} \qquad W_K^{(2)} = \begin{bmatrix} 3.3 & 3.4 \\ 3.5 & 3.6 \\ 3.7 & 3.8 \\ 3.9 & 4.0 \end{bmatrix} \qquad W_V^{(2)} = \begin{bmatrix} 4.1 & 4.2 \\ 4.3 & 4.4 \\ 4.5 & 4.6 \\ 4.7 & 4.8 \end{bmatrix}$
Let’s assume our input sequence for queries, keys, and values are:
Input Query: $\textbf{Q} = \begin{bmatrix} 0.5 & 0.6 & 0.7 & 0.8 \end{bmatrix}$
Input Key: $\textbf{K} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \end{bmatrix}$
Input Value: $\textbf{V} = \begin{bmatrix} 1.3 & 1.4 & 1.5 & 1.6 \\ 1.7 & 1.8 & 1.9 & 2.0 \\ 2.1 & 2.2 & 2.3 & 2.4 \end{bmatrix}$
Head 1 Linear Transformations
Query:
$$\textbf{Q}_{\text{Head 1}} = \textbf{Q} \cdot W_Q^{(1)} = \begin{bmatrix} 0.5 & 0.6 & 0.7 & 0.8 \end{bmatrix} \cdot \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \\ 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix} = \begin{bmatrix} 1.14 & 1.40 \end{bmatrix}$$
Key:
$$\textbf{K}_{\text{Head 1}} = \textbf{K} \cdot W_K^{(1)} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \end{bmatrix} \cdot \begin{bmatrix} 0.9 & 1.0 \\ 1.1 & 1.2 \\ 1.3 & 1.4 \\ 1.5 & 1.6 \end{bmatrix} = \begin{bmatrix} 1.30 & 1.40 \\ 3.22 & 3.48 \\ 5.14 & 5.56 \end{bmatrix}$$
Value:
$$\textbf{V}_{\text{Head 1}} = \textbf{V} \cdot W_V^{(1)} = \begin{bmatrix} 1.3 & 1.4 & 1.5 & 1.6 \\ 1.7 & 1.8 & 1.9 & 2.0 \\ 2.1 & 2.2 & 2.3 & 2.4 \end{bmatrix} \cdot \begin{bmatrix} 1.7 & 1.8 \\ 1.9 & 2.0 \\ 2.1 & 2.2 \\ 2.3 & 2.4 \end{bmatrix} = \begin{bmatrix} 11.70 & 12.28 \\ 14.90 & 15.64 \\ 18.10 & 19.00 \end{bmatrix}$$
Head 2 Linear Transformations
Query:
$$\textbf{Q}_{\text{Head 2}} = \textbf{Q} \cdot W_Q^{(2)} = \begin{bmatrix} 0.5 & 0.6 & 0.7 & 0.8 \end{bmatrix} \cdot \begin{bmatrix} 2.5 & 2.6 \\ 2.7 & 2.8 \\ 2.9 & 3.0 \\ 3.1 & 3.2 \end{bmatrix} = \begin{bmatrix} 7.38 & 7.64 \end{bmatrix}$$
Key:
$$\textbf{K}_{\text{Head 2}} = \textbf{K} \cdot W_K^{(2)} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \end{bmatrix} \cdot \begin{bmatrix} 3.3 & 3.4 \\ 3.5 & 3.6 \\ 3.7 & 3.8 \\ 3.9 & 4.0 \end{bmatrix} = \begin{bmatrix} 3.70 & 3.80 \\ 9.46 & 9.72 \\ 15.22 & 15.64 \end{bmatrix}$$
Value:
$$\textbf{V}_{\text{Head 2}} = \textbf{V} \cdot W_V^{(2)} = \begin{bmatrix} 1.3 & 1.4 & 1.5 & 1.6 \\ 1.7 & 1.8 & 1.9 & 2.0 \\ 2.1 & 2.2 & 2.3 & 2.4 \end{bmatrix} \cdot \begin{bmatrix} 4.1 & 4.2 \\ 4.3 & 4.4 \\ 4.5 & 4.6 \\ 4.7 & 4.8 \end{bmatrix} = \begin{bmatrix} 25.62 & 26.20 \\ 32.66 & 33.40 \\ 39.70 & 40.60 \end{bmatrix}$$
Compute Attention Scores
$$\text{Attention Scores}_{\text{Head 1}} = \textbf{Q}_{\text{Head 1}} \cdot \textbf{K}_{\text{Head 1}}^\top$$
$$\text{Attention Scores}_{\text{Head 1}} = \begin{bmatrix} 1.14 \cdot 1.30 + 1.40 \cdot 1.40 \\ 1.14 \cdot 3.22 + 1.40 \cdot 3.48 \\ 1.14 \cdot 5.14 + 1.40 \cdot 5.56 \end{bmatrix} = \begin{bmatrix} 1.482 + 1.960 \\ 3.6708 + 4.8720 \\ 5.8596 + 7.7840 \end{bmatrix} = \begin{bmatrix} 3.442 \\ 8.5428 \\ 13.6436 \end{bmatrix}$$
$$\text{Attention Scores}_{\text{Head 2}} = \textbf{Q}_{\text{Head 2}} \cdot \textbf{K}_{\text{Head 2}}^\top$$
$$\text{Attention Scores}_{\text{Head 2}} = \begin{bmatrix} 7.38 \cdot 3.70 + 7.64 \cdot 3.80 \\ 7.38 \cdot 9.46 + 7.64 \cdot 9.72 \\ 7.38 \cdot 15.22 + 7.64 \cdot 15.64 \end{bmatrix} = \begin{bmatrix} 27.3060 + 29.0320 \\ 69.8148 + 74.2608 \\ 112.3236 + 119.4896 \end{bmatrix} = \begin{bmatrix} 56.338 \\ 144.0756 \\ 231.8132 \end{bmatrix}$$
$$\text{Softmax}_{\text{Head 1}} = \left[\frac{e^{3.442}}{\text{Sum}},\ \frac{e^{8.5428}}{\text{Sum}},\ \frac{e^{13.6436}}{\text{Sum}}\right] \approx \begin{bmatrix} 0.000037 \\ 0.0061 \\ 0.9939 \end{bmatrix}$$
$$\text{Softmax}_{\text{Head 2}} \approx \begin{bmatrix} 6 \times 10^{-77} \\ 8 \times 10^{-39} \\ 1.0 \end{bmatrix}$$
Context Vectors
$$\text{Context}_{\text{Head 1}} = \text{Softmax}_{\text{Head 1}} \cdot \textbf{V}_{\text{Head 1}}$$
$$\text{Context}_{\text{Head 1}} = \begin{bmatrix} 0.000037 & 0.0061 & 0.9939 \end{bmatrix} \cdot \begin{bmatrix} 11.70 & 12.28 \\ 14.90 & 15.64 \\ 18.10 & 19.00 \end{bmatrix} \approx \begin{bmatrix} 18.08 & 18.98 \end{bmatrix}$$
$$\text{Context}_{\text{Head 2}} = \text{Softmax}_{\text{Head 2}} \cdot \textbf{V}_{\text{Head 2}}$$
$$\text{Context}_{\text{Head 2}} = \begin{bmatrix} 6 \times 10^{-77} & 8 \times 10^{-39} & 1.0 \end{bmatrix} \cdot \begin{bmatrix} 25.62 & 26.20 \\ 32.66 & 33.40 \\ 39.70 & 40.60 \end{bmatrix} \approx \begin{bmatrix} 39.70 & 40.60 \end{bmatrix}$$
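As a sanity check, the same two heads can be reproduced in a few lines of PyTorch and their context vectors concatenated. Since the concatenated size (2 + 2 = 4) already equals embed_dim, a final $W_O$ could simply be the identity here. The numbers are the ones from this toy example, and the dot products are left unscaled to match the hand computation above (a minimal sketch, not a general implementation):
import torch
import torch.nn.functional as F

Q = torch.tensor([[0.5, 0.6, 0.7, 0.8]])            # (1, 4) query
K = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [0.5, 0.6, 0.7, 0.8],
                  [0.9, 1.0, 1.1, 1.2]])            # (3, 4) keys
V = torch.tensor([[1.3, 1.4, 1.5, 1.6],
                  [1.7, 1.8, 1.9, 2.0],
                  [2.1, 2.2, 2.3, 2.4]])            # (3, 4) values

# Per-head projection matrices (embed_dim=4 -> head_dim=2), values from the example
W_q1 = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
W_k1 = torch.tensor([[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]])
W_v1 = torch.tensor([[1.7, 1.8], [1.9, 2.0], [2.1, 2.2], [2.3, 2.4]])
W_q2 = torch.tensor([[2.5, 2.6], [2.7, 2.8], [2.9, 3.0], [3.1, 3.2]])
W_k2 = torch.tensor([[3.3, 3.4], [3.5, 3.6], [3.7, 3.8], [3.9, 4.0]])
W_v2 = torch.tensor([[4.1, 4.2], [4.3, 4.4], [4.5, 4.6], [4.7, 4.8]])

def head(Wq, Wk, Wv):
    q, k, v = Q @ Wq, K @ Wk, V @ Wv
    weights = F.softmax(q @ k.T, dim=-1)   # unscaled dot-product, matching the hand computation
    return weights @ v                     # (1, 2) context vector for this head

out = torch.cat([head(W_q1, W_k1, W_v1), head(W_q2, W_k2, W_v2)], dim=-1)
print(out)  # ~[[18.08, 18.98, 39.70, 40.60]] -- concatenated heads, already of size embed_dim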
Example Code of Multi-head Attention¶
A simplified version of how the MultiheadAttention class is implemented in PyTorch (assuming batch-first inputs of shape (N, L, E)). As the formula above shows, the output of MultiheadAttention must have the same dimension as the input Q, K, V.
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiheadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.0):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
# Linear layers for query, key, and value
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
# Output projection
self.out_proj = nn.Linear(embed_dim, embed_dim)
def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
# Project input tensors to query, key, and value
query = self.q_proj(query)
key = self.k_proj(key)
value = self.v_proj(value)
        # Reshape for multi-head attention (batch-first inputs: (N, L, E))
        batch_size, seq_len, embed_dim = query.size()
        src_len = key.size(1)  # key/value length may differ from the query length (cross-attention)
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, src_len, self.num_heads, self.head_dim).transpose(1, 2)
# Scaled dot-product attention
attn_output, attn_output_weights = self.scaled_dot_product_attention(query, key, value, attn_mask, key_padding_mask)
# Concatenate heads and project
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
attn_output = self.out_proj(attn_output)
if need_weights:
return attn_output, attn_output_weights
else:
return attn_output
def scaled_dot_product_attention(self, query, key, value, attn_mask=None, key_padding_mask=None):
# Calculate attention scores
attn_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
if attn_mask is not None:
attn_scores = attn_scores + attn_mask
if key_padding_mask is not None:
attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
# Normalize attention scores to probabilities
attn_weights = F.softmax(attn_scores, dim=-1)
if self.dropout > 0.0:
attn_weights = F.dropout(attn_weights, p=self.dropout)
# Weighted sum of values (Each item of the output corresponds to the weighted sum of V vectors w.r.t. the certain query)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
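A quick usage sketch of the simplified class above (random, batch-first tensors; just to confirm that the output keeps the input shape required by the residual connection):
import torch

embed_dim, num_heads = 8, 2
mha = MultiheadAttention(embed_dim, num_heads)

x = torch.randn(4, 10, embed_dim)     # (batch=4, seq_len=10, embed_dim=8); self-attention: Q = K = V = x
out, weights = mha(x, x, x)

print(out.shape)      # torch.Size([4, 10, 8])    -- same shape as the input
print(weights.shape)  # torch.Size([4, 2, 10, 10]) -- per-head attention weights in this simplified version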
--- Transformers¶
Attention can be used as a module in a variety of models such as RNNs and MLPs. However, what makes the Transformer special is that it builds its entire architecture solely based on Attention, replacing RNNs/CNNs for handling sequential data.
In almost all cases, Transformers use Multi-Head Attention and Positional Encoding.
Easy Parallelism In Transformer
In RNNs, whether it's cross-attention or self-attention, computations cannot be parallelized because the hidden states (from which the Q, K and V in attention may be picked) are generated step by step. Attention can only be computed within the currently "generated" window; you can't apply attention to the entire input sequence at once.
Transformers, on the other hand, can compute in parallel because all their Q, K, and V vectors are obtained by applying linear transformations to the entire sequence at once, without relying on the output of the previous token. For the whole sequence, each token independently generates its own Q (query vector), and the attention for every query is computed in parallel. Note that in transformers, this parallelism specifically refers to parallel attention computation.
Parallelism in the Transformer is the key to making models bigger! The parallel attention computation is a significant advantage of transformers compared with the original RNN attention model, making transformers highly efficient for large-scale sequence processing tasks. Transformers are the foundation for many state-of-the-art models in natural language processing, such as BERT and GPT.
Next Token Prediction Task: Slow Inference
For Next Token Prediction, Transformers can be parallelized during training because the target sequence is known, and targets at every position can be predicted simultaneously, simply by controlling the flow of information with masks. When training an autoregressive model (such as GPT or a translation model decoder), you already have the complete target output sequence, for example:
Input: [A, B, C, D]
Target: [B, C, D, <EOS>]
This means for each token, you know the “ground truth” next word. So you can:
- Feed the entire target sequence into the model as decoder input, e.g., [A, B, C, D]
- The model uses masked self-attention to predict the next word at every position simultaneously
  - The first position predicts B, the second predicts C, and so on;
  - All Q/K/V projections, attention, linear layers, etc., can be computed in parallel!
- Add a mask to prevent information leakage
  - Although we feed in the full sequence, a causal mask ensures each position can only see what comes before it (it prevents peeking at the correct answer); see the sketch below.
  - This preserves the property of autoregressive modeling.
For Next Token Prediction, Transformers are highly parallel during training, but not during inference. The inference process is essentially like an RNN: generating tokens step by step, in order, where each step depends on the output of the previous step. Specifically, at step t, Q comes from the token at position t-1, K and V are from all previous tokens, so step t can only start after step t-1 has been predicted. This style of generation cannot process each generation position in parallel.
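To make the contrast concrete, here is a minimal sketch. The causal mask below is real PyTorch; the commented lines use a placeholder model (any decoder mapping token ids to next-token logits) purely to illustrate the two loops, and are not a specific library API:
import torch

T = 4
# Causal mask: -inf above the diagonal, so position i can only attend to positions <= i
causal_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# (an nn.Transformer instance provides generate_square_subsequent_mask for the same purpose)

# Training: one forward pass scores every position at once; the mask prevents peeking ahead
#   logits = model(input_ids, attn_mask=causal_mask)          # (N, T, vocab_size)

# Inference: tokens are generated one at a time; step t needs the output of step t-1
#   ids = torch.tensor([[bos_id]])
#   for _ in range(max_new_tokens):
#       logits = model(ids)                                   # re-encodes the whole prefix
#       next_id = logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick of the next token
#       ids = torch.cat([ids, next_id], dim=1)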
Also, transformer inference is slow because attention computation is quadratic complexity (O(n²)).
KV Cache
The KV cache improves inference speed in next-token generation by caching the K and V projections of previous tokens so that they are not recomputed at every step.
KV Cache means storing the K and V vectors generated by linear transformations of previous tokens, so that at each inference step, you don't have to recompute these for all previous tokens. In attention computation, for input token $x_t$, we compute:
$K_t = x_t \cdot W^K$
$V_t = x_t \cdot W^V$
If you don't use KV cache, at each generation step you'd have to repeat these linear transformations for all previous $x_1, x_2, ..., x_{t-1}$, which is very inefficient.
However, KV Cache does not reduce the computational complexity of Attention, which remains O(n²)!
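A minimal single-head sketch of the idea (all names here are illustrative, not a library API): the K and V projections of each new token are appended to a cache and reused, while only the new token's query is computed at each step.
import torch

d_model = 8
W_q = torch.randn(d_model, d_model)
W_k = torch.randn(d_model, d_model)
W_v = torch.randn(d_model, d_model)

k_cache, v_cache = [], []          # grows by one entry per generated token

def attend_step(x_t):
    """x_t: (batch, 1, d_model) embedding of the newest token only."""
    q_t = x_t @ W_q                # only the new token needs a query
    k_cache.append(x_t @ W_k)      # K, V of the new token are computed once and cached
    v_cache.append(x_t @ W_v)
    K = torch.cat(k_cache, dim=1)  # (batch, t, d_model) -- reused, not recomputed
    V = torch.cat(v_cache, dim=1)
    scores = q_t @ K.transpose(-2, -1) / d_model ** 0.5   # still O(t) work per step, O(n^2) overall
    weights = torch.softmax(scores, dim=-1)
    return weights @ V             # (batch, 1, d_model) context for the new token

x = torch.randn(2, 1, d_model)
for _ in range(5):                 # pretend we generate 5 tokens
    ctx = attend_step(x)
    x = torch.randn(2, 1, d_model) # in practice the next embedding comes from the model's output
print(ctx.shape)                   # torch.Size([2, 1, 8])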
--- nn.Transformer Related Classes¶
The key component for implementation is the nn.Transformer module. It is a complete Encoder–Decoder Transformer, internally consisting of:
- Encoder: uses TransformerEncoderLayer (self-attention)
- Decoder: uses TransformerDecoderLayer (masked self-attention + cross-attention)
Therefore, PyTorch’s nn.Transformer fully implements the standard Transformer architecture (as in “Attention Is All You Need”) with cross-attention.
We can see that in TransformerEncoderLayer and TransformerDecoderLayer, the structure is not just a multi-head attention; instead, it is a sequence of multi-head attention followed by a feedforward network (multi-head att -> FFN). In fact, in the original “Attention Is All You Need” paper, each encoder and decoder layer in the Transformer is structured as multi-head attention followed by FFN. We may ask: Attention itself is already very powerful, so why do we still need to add an FFN after each layer? Why is it not enough to just add another activation?
Attention is essentially a weighted average calculator: it can only reorder and aggregate information, but it is not a feature transformer. What multi-head attention essentially does is use attention weights to compute weighted sums of features, so the new features created are just linear combinations of the original features. From the perspective of linear algebra, linear combinations do not increase the dimensionality of the space, so attention does not create a new feature space. On the other hand, the linear layer in the FFN can truly change the feature space dimension, and with the activation function, the transformation done by the FFN becomes nonlinear and expressive enough to fit complex functions—completely transforming the feature space into a new one.
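A tiny numerical illustration of this point: whatever the attention weights are, each attention output is a weighted average of the rows of V, so stacking outputs on top of V adds no new directions (the rank does not grow); only the FFN's linear maps and nonlinearity move the representation into a genuinely new feature space. A minimal sketch:
import torch

V = torch.randn(5, 8, dtype=torch.float64)                          # 5 value vectors of dimension 8
W = torch.softmax(torch.randn(3, 5, dtype=torch.float64), dim=-1)   # 3 rows of attention weights (each sums to 1)
out = W @ V                                                         # attention outputs: weighted averages of V's rows

print(torch.linalg.matrix_rank(V))                    # tensor(5)
print(torch.linalg.matrix_rank(torch.cat([V, out])))  # still tensor(5): no new directions were created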
Details of nn.TransformerEncoderLayer and nn.TransformerEncoder¶
For nn.TransformerEncoderLayer and nn.TransformerEncoder, we only need to understand the structure of nn.TransformerEncoderLayer, because nn.TransformerEncoder is simply a stack of nn.TransformerEncoderLayers.
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=nhead,
batch_first=True
)
# FFN shape of ffn out == shape of ffn in
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
# LayerNorm + Dropout
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
src: (N, S, E)
src_mask: (S, S) or None
src_key_padding_mask: (N, S) or None
"""
attn_out, _ = self.self_attn( # shape of attn_out is the same as src: (N, S, E)
src, src, src, # let Q=K=V=src(self-attention)
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask
)
# residual connection: src + attn_out -> norm
src = self.norm1(src + self.dropout1(attn_out))
# residual connection: src + ffn_out -> norm
ffn_out = self.linear2(F.relu(self.linear1(src)))
out = self.norm2(src + self.dropout2(ffn_out))
return out
A flow chart to visualize the nn.TransformerEncoderLayer architecture:
┌───────────────┐
src ───────▶│ Multi-Head SA │─────┐
└───────────────┘ │
▼
Add src & Norm (residual connection)
│
▼
┌──────────────────────┐
│ FFN │
└──────────────────────┘
│
▼
Add src & Norm (residual connection)
For TransformerEncoderLayer and TransformerEncoder, the output and input dimensions are exactly the same.
| Module | Input Shape | Output Shape |
|---|---|---|
| TransformerEncoderLayer | (N, S, E) | (N, S, E) |
| TransformerEncoder | (N, S, E) | (N, S, E) |
where
- S = sequence length
- N = batch size
- E = embedding dimension for each token (also the dimension after linear transformation to Q, K, V).
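A quick sketch with the real PyTorch modules to confirm that the input and output shapes match (batch_first=True so the shapes read (N, S, E)):
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)

src = torch.randn(2, 10, 512)     # (N=2, S=10, E=512)
print(layer(src).shape)           # torch.Size([2, 10, 512])
print(encoder(src).shape)         # torch.Size([2, 10, 512])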
Details of nn.TransformerDecoderLayer and nn.TransformerDecoder¶
For nn.TransformerDecoderLayer and nn.TransformerDecoder, we only need to understand the structure of nn.TransformerDecoderLayer, because nn.TransformerDecoder is simply a stack of nn.TransformerDecoderLayers.
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
super().__init__()
# 1. Masked Self-Attention (for target sequence)
self.self_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=nhead,
batch_first=True
)
# 2. Cross-Attention (attend to encoder output)
self.multihead_attn = nn.MultiheadAttention(
embed_dim=d_model,
num_heads=nhead,
batch_first=True
)
# 3. Feed-Forward Network (same as in encoder)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
# LayerNorm + Dropout
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
def forward(
self,
tgt, # (N, T, E)
memory, # (N, S, E)
tgt_mask=None, # (T, T)
memory_mask=None, # (T, S)
tgt_key_padding_mask=None, # (N, T)
memory_key_padding_mask=None# (N, S)
):
"""
tgt: target sequence (N, T, E)
memory: encoder output (N, S, E)
tgt_mask: look-ahead mask for self-attention
memory_mask: optional mask for encoder-decoder attention
"""
# --- 1. Masked Self-Attention ---
tgt2, _ = self.self_attn(
tgt, tgt, tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask
)
tgt = self.norm1(tgt + self.dropout1(tgt2))
# --- 2. Cross-Attention (Q from tgt, K/V from encoder memory) ---
tgt2, _ = self.multihead_attn(
tgt, memory, memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask
)
tgt = self.norm2(tgt + self.dropout2(tgt2))
# --- 3. Feed-Forward Network ---
ffn_out = self.linear2(F.relu(self.linear1(tgt)))
out = self.norm3(tgt + self.dropout3(ffn_out))
return out
A flow chart to visualize the nn.TransformerDecoderLayer architecture:
┌───────────────────┐
tgt ───────▶│ Masked Self-Attn │────┐
└───────────────────┘ │
▼
Add tgt & Norm (residual connection)
│
▼
┌───────────────────┐
memory ────▶│ Cross-Attention │────┐
└───────────────────┘ │
▼
Add tgt & Norm (residual connection)
│
▼
┌──────────────────────┐
│ FFN │
└──────────────────────┘
│
▼
Add tgt & Norm (residual connection)
For TransformerDecoderLayer and TransformerDecoder, the output and input dimensions are exactly the same.
| Module | Input Shape | Output Shape |
|---|---|---|
| TransformerDecoderLayer | (N, T, E) | (N, T, E) |
| TransformerDecoder | (N, T, E) | (N, T, E) |
where
- T = target sequence length
- S = source sequence length (from the encoder output)
- N = batch size
- E = embedding dimension for each token (also the dimension of Q, K, V after linear transformation).
The difference between Encoder and Decoder Layers is that the Decoder Layer contains two attention blocks:
- Masked Self-Attention, where the decoder attends to previous tokens only.
- Cross-Attention, where the decoder attends to the encoder’s output.
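A quick sketch with the real PyTorch modules to confirm the decoder shapes (batch_first=True; the output follows the target length T, not the source length S):
import torch
import torch.nn as nn

layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=6)

tgt = torch.randn(2, 7, 512)        # (N=2, T=7, E=512) target sequence
memory = torch.randn(2, 10, 512)    # (N=2, S=10, E=512) encoder output
print(layer(tgt, memory).shape)     # torch.Size([2, 7, 512])
print(decoder(tgt, memory).shape)   # torch.Size([2, 7, 512])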
Details of nn.Transformer¶
For nn.Transformer, we only need to understand that it is composed of both an encoder and a decoder.
Internally, it combines a stack of nn.TransformerEncoderLayers and a stack of nn.TransformerDecoderLayers into a complete Transformer architecture for sequence-to-sequence modeling.
class Transformer(nn.Module):
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
super().__init__()
# Encoder: stack of TransformerEncoderLayers
encoder_layer = nn.TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
# Decoder: stack of TransformerDecoderLayers
decoder_layer = nn.TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
batch_first=True
)
self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
self.d_model = d_model
self.nhead = nhead
def forward(
self,
src, # (N, S, E)
tgt, # (N, T, E)
src_mask=None, # (S, S)
tgt_mask=None, # (T, T)
memory_mask=None, # (T, S)
src_key_padding_mask=None, # (N, S)
tgt_key_padding_mask=None, # (N, T)
memory_key_padding_mask=None# (N, S)
):
"""
src: source sequence (input)
tgt: target sequence (output)
"""
# Encoder: encodes the source sequence
memory = self.encoder(
src,
mask=src_mask,
src_key_padding_mask=src_key_padding_mask
) # (N, S, E)
# Decoder: generates the output sequence
out = self.decoder(
tgt,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask
) # (N, T, E)
return out
A flow chart to visualize the overall nn.Transformer architecture:
┌───────────────────────┐
src ───────────▶│ Transformer Encoder │─────┐
└───────────────────────┘ │
▼
memory (encoder output)
│
▼
┌───────────────────────┐
tgt ─────────────────────────────▶│ Transformer Decoder │─────▶ output
└───────────────────────┘
For Transformer, the encoder and decoder parts have matching embedding dimensions, and both preserve input–output dimensionality.
| Module | Input Shape | Output Shape |
|---|---|---|
| TransformerEncoder | (N, S, E) | (N, S, E) |
| TransformerDecoder | (N, T, E) | (N, T, E) |
| Transformer | (src: N, S, E), (tgt: N, T, E) | (N, T, E) |
where
- S = source sequence length
- T = target sequence length
- N = batch size
- E = embedding dimension (shared between encoder and decoder)
Important parameters of nn.Transformer
| Parameter | Meaning | Role / Corresponding Transformer Component |
|---|---|---|
| d_model | The main dimensionality of the Transformer | The dimension of each token embedding, and also the size of Q, K, V vectors after linear projection. |
| nhead | Number of attention heads | Defines how many parallel attention heads are used in the Multi-Head Attention mechanism. |
| num_encoder_layers | Number of encoder layers | Determines how many TransformerEncoderLayers are stacked to form the encoder. |
| num_decoder_layers | Number of decoder layers | Determines how many TransformerDecoderLayers are stacked to form the decoder. |
| dim_feedforward | Hidden dimension in the feed-forward network (FFN) | Typically 4× larger than d_model; controls the capacity of the FFN sub-layer. |
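A quick sketch with the built-in nn.Transformer to confirm the shapes in the table above (batch_first=True; random tensors for illustration only):
import torch
import torch.nn as nn

model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
                       num_decoder_layers=6, dim_feedforward=2048, batch_first=True)

src = torch.randn(2, 10, 512)   # (N, S, E)
tgt = torch.randn(2, 7, 512)    # (N, T, E)
out = model(src, tgt)
print(out.shape)                # torch.Size([2, 7, 512]) -- follows the target length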
--- Implementation: Using Transformer for Translation (Seq2Seq, Encoder and Decoder)¶
We use a seq2seq translation task here because the Transformer, as implemented in the original paper, is a standard Seq2Seq model: the input is a source-language sequence and the output is a target-language sequence, intended for machine translation tasks.
Generation starts with <bos> and predicts the next token iteratively until <eos>. The decoder input and output have the same shape; the output can be regarded as the input shifted by one step, so the next word is read from the last index of the output.
The decoder has two attention layers: the first is the self-attention of the target sequence (Q, K, V all come from tgt), and the second is the cross-attention, where Q comes from tgt and K, V come from the encoder output (memory).
As discussed before, parallelization is limited during inference. The model does not have access to the full target sequence; it generates one token at a time, so the prediction of the next token depends on the tokens generated so far.
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# ----------
# Data Preparation
# ----------
# Define a simple dataset
data = [
("I am a student", "Je suis un étudiant"),
("He is a teacher", "Il est un enseignant"),
("She is a nurse", "Elle est une infirmière"),
("I love you", "Je t'aime"),
("How are you?", "Comment ça va?"),
]
# Build a vocabulary
def build_vocab(sentences):
vocab = {"<unk>": 0, "<pad>": 1, "<bos>": 2, "<eos>": 3}
for sentence in sentences:
for word in sentence.split():
if word not in vocab:
vocab[word] = len(vocab)
return vocab
src_sentences = [pair[0] for pair in data]
tgt_sentences = [pair[1] for pair in data]
src_vocab = build_vocab(src_sentences)
tgt_vocab = build_vocab(tgt_sentences)
print("SRC_VOCAB_SIZE:", len(src_vocab))
print("TGT_VOCAB_SIZE ", len(tgt_vocab))
print("src_vocab: ", src_vocab)
print("tgt_vocab: ", tgt_vocab)
# Convert sentences to tensors
def sentence_to_tensor(sentence, vocab):
return torch.tensor([vocab[word] for word in sentence.split()], dtype=torch.long)
class TranslationDataset(Dataset):
def __init__(self, data, src_vocab, tgt_vocab):
self.data = data
self.src_vocab = src_vocab
self.tgt_vocab = tgt_vocab
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
src, tgt = self.data[idx]
src_tensor = sentence_to_tensor(src, self.src_vocab)
tgt_tensor = sentence_to_tensor(tgt, self.tgt_vocab)
return src_tensor, tgt_tensor
# DataLoader
BATCH_SIZE = 2
PAD_IDX = src_vocab["<pad>"]
def collate_fn(batch):
src_batch, tgt_batch = [], []
for src_sample, tgt_sample in batch:
src_batch.append(torch.cat([torch.tensor([src_vocab["<bos>"]]), src_sample, torch.tensor([src_vocab["<eos>"]])]))
tgt_batch.append(torch.cat([torch.tensor([tgt_vocab["<bos>"]]), tgt_sample, torch.tensor([tgt_vocab["<eos>"]])]))
src_batch = nn.utils.rnn.pad_sequence(src_batch, padding_value=PAD_IDX)
tgt_batch = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=PAD_IDX)
return src_batch, tgt_batch
train_loader = DataLoader(TranslationDataset(data, src_vocab, tgt_vocab), batch_size=BATCH_SIZE, collate_fn=collate_fn)
print("Raw data before padding: ")
for src, tgt in data:
src_tensor = sentence_to_tensor(src, src_vocab)
tgt_tensor = sentence_to_tensor(tgt, tgt_vocab)
print("src_tensor: ", src_tensor)
print("tgt_tensor: ", tgt_tensor)
SRC_VOCAB_SIZE: 18
TGT_VOCAB_SIZE 18
src_vocab: {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'I': 4, 'am': 5, 'a': 6, 'student': 7, 'He': 8, 'is': 9, 'teacher': 10, 'She': 11, 'nurse': 12, 'love': 13, 'you': 14, 'How': 15, 'are': 16, 'you?': 17}
tgt_vocab: {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'Je': 4, 'suis': 5, 'un': 6, 'étudiant': 7, 'Il': 8, 'est': 9, 'enseignant': 10, 'Elle': 11, 'une': 12, 'infirmière': 13, "t'aime": 14, 'Comment': 15, 'ça': 16, 'va?': 17}
Raw data before padding:
src_tensor: tensor([4, 5, 6, 7])
tgt_tensor: tensor([4, 5, 6, 7])
src_tensor: tensor([ 8, 9, 6, 10])
tgt_tensor: tensor([ 8, 9, 6, 10])
src_tensor: tensor([11, 9, 6, 12])
tgt_tensor: tensor([11, 9, 12, 13])
src_tensor: tensor([ 4, 13, 14])
tgt_tensor: tensor([ 4, 14])
src_tensor: tensor([15, 16, 17])
tgt_tensor: tensor([15, 16, 17])
# ----------
# Model Architecture
# ----------
class TransformerModel(nn.Module):
def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1):
super(TransformerModel, self).__init__()
self.d_model = d_model
self.src_embedding = nn.Embedding(src_vocab_size, d_model)
self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
self.positional_encoding = self._generate_positional_encoding(d_model)
self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
self.fc_out = nn.Linear(d_model, tgt_vocab_size)
def _generate_positional_encoding(self, d_model, max_len=5000):
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(1) # Add batch dimension
return pe
def forward(self, src, tgt):
src = self.src_embedding(src) * math.sqrt(self.d_model)
tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
# Apply positional encoding
src = src + self.positional_encoding[:src.size(0), :]
tgt = tgt + self.positional_encoding[:tgt.size(0), :]
        # Note: for strictly autoregressive training, a causal mask should also be passed here,
        # e.g. tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0))
        output = self.transformer(src, tgt)
final_layer = self.fc_out(output)
return final_layer
# ----------
# Training
# ----------
# Training configuration
SRC_VOCAB_SIZE = len(src_vocab)
TGT_VOCAB_SIZE = len(tgt_vocab)
model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=0.0005)
# Training loop
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
model.train()
total_loss = 0
for src_batch, tgt_batch in train_loader:
print("src, tgt: ", src_batch.shape, tgt_batch.shape) # src and tgt after padding
optimizer.zero_grad()
# Remove last token from tgt_batch for input
tgt_input = tgt_batch[:-1, :]
# Forward pass
output = model(src_batch, tgt_input)
# Remove the first token from tgt_batch for target
tgt_out = tgt_batch[1:, :].reshape(-1)
# Calculate loss
loss = criterion(output.reshape(-1, output.shape[-1]), tgt_out)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {total_loss / len(train_loader):.4f}')
/Users/yy544/.local/share/virtualenvs/Cookbook-0P5uvQVm/lib/python3.10/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 1/10, Loss: 3.0812
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 2/10, Loss: 2.2407
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 3/10, Loss: 1.8884
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 4/10, Loss: 1.3930
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 5/10, Loss: 1.2496
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 6/10, Loss: 0.8894
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 7/10, Loss: 0.5083
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 8/10, Loss: 0.3647
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 9/10, Loss: 0.2143
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([6, 2]) torch.Size([6, 2])
src, tgt: torch.Size([5, 1]) torch.Size([5, 1])
Epoch 10/10, Loss: 0.1677
# ----------
# Inference
# ----------
def greedy_decode(model, src, max_len, start_symbol):
    src = src.unsqueeze(1)  # (S,) -> (S, 1): add a batch dimension
    # Encode the source once; scale embeddings by sqrt(d_model) to match the training forward pass
    memory = model.transformer.encoder(model.src_embedding(src) * math.sqrt(model.d_model) + model.positional_encoding[:src.size(0), :])
ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long) # Start with <bos>, predict the next token iteratively until <eos>
for i in range(max_len-1):
tgt_mask = model.transformer.generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
        out = model.transformer.decoder(model.tgt_embedding(ys) * math.sqrt(model.d_model) + model.positional_encoding[:ys.size(0), :], memory, tgt_mask=tgt_mask)
out = model.fc_out(out)
        _, next_word = torch.max(out[-1, :], dim=1) # The next word is read from the last index of the output; the output has the same shape as the decoder input, shifted by one step.
next_word = next_word.item()
# The entire sequence of generated tokens up to the newly generated next_word as Q for the transformer decoder
ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
if next_word == tgt_vocab['<eos>']:
break
return ys
# Translate a sentence
model.eval()
src_sentence = "I love you"
src_tensor = torch.tensor([src_vocab["<bos>"]] + [src_vocab[word] for word in src_sentence.split()] + [src_vocab["<eos>"]])
# Generate translation
translated_sentence = greedy_decode(model, src_tensor, max_len=10, start_symbol=tgt_vocab['<bos>'])
# Convert tokens back to words
translated_words = [list(tgt_vocab.keys())[list(tgt_vocab.values()).index(idx)] for idx in translated_sentence]
print(" ".join(translated_words))
<bos> Je t'aime <eos>