Transformers Basics

--- Attention mechanism

The attention mechanism allows a model to dynamically focus on different parts of the input sequence when generating each part of the output sequence, rather than treating all input tokens as equally important.

The first widely used attention mechanism was introduced in the 2014 paper “Neural Machine Translation by Jointly Learning to Align and Translate” by Bahdanau, Cho, and Bengio. It is an RNN-based attention model. (For key attention concepts and a basic example of how to compute attention scores, see the RNN-with-Attention notes.)

--- Self-attention

Transformers use self-attention, whereas earlier models mainly used attention as cross-attention between the encoder and decoder.

Self-attention is still a form of attention. We'll show a detailed example in the "Self-Attention" section. It follows the standard mechanism: use Q and K to compute attention scores and weights, and then use V and the weights to compute context vectors. Self-attention allows the model to determine the importance of different parts of the input sequence when making predictions.

Takeaways:

  1. When was positional encoding first proposed?
  2. What is the usage of positional encoding, and what function is commonly used?
  3. What is the difference between cross-attention and self-attention?
  4. When was multi-head attention first proposed?
  5. In Multi-Head Attention, are the input Q, K, V for each head the same? What is the output of each head?
  6. Why must the output dimension after concatenation in Multi-Head Attention be equal to the input embedding dimension?
  7. Can parallelism be achieved in a Transformer-based next-token prediction task? Is there any difference between training and inference?
  8. During inference in a next-token prediction task, why do we have a K, V cache but not a Q cache?
  9. What is the difference between nn.TransformerEncoderLayer and nn.TransformerEncoder? What are the input & output of them? If I want to implement a certain functionality, can I achieve it with either one?
  10. What is the difference between nn.TransformerDecoderLayer and nn.TransformerDecoder? What are the input & output of them? If I want to implement a certain functionality, can I achieve it with either one?
  11. What is the input and output of nn.Transformer?

--- Positional Encoding

Positional encoding was first introduced in the Transformer model proposed in the 2017 paper “Attention Is All You Need.” It is used to inject information about the relative or absolute position of tokens in the sequence. The paper proposed two types:

  • fixed sinusoidal positional encoding
  • learnable positional embeddings.

Positional encoding is needed because the Transformer architecture removes recurrence and convolution and therefore lacks an inherent sense of word order. Transformers do not inherently understand the order of tokens in the input (they treat the input sequence as a set of vectors without any sequential structure), so positional encoding is required.

Implementation Highlights

  • Sinusoidal Functions: The positional encoding uses sinusoidal functions of different frequencies to generate a unique positional vector for each position. This ensures that each position has a distinct encoding and that nearby positions have similar encodings, capturing the sequential nature of the data. Sinusoidal functions include:

    • sin(x)
    • cos(x)
    • Or any of their variants (sin(ax), cos(bx), sin(kx+c), etc.)
  • The sine and cosine positional encoding formulas from "Attention Is All You Need": $$ PE_{(pos,2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right) $$ $$ PE_{(pos,2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right) $$ Meaning:

    • pos: position (0, 1, 2, 3, ...)
    • i: dimension index
    • $d_{\text{model}}$: embedding dimension
  • Equations to Code:

    • position: A tensor representing the positions in the sequence.

      position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
      
    • div_term: A tensor that scales the positions to generate different frequencies for the sinusoidal functions.

      The denominator in the formula: $$ 10000^{2i/d_{\text{model}}} $$ Expressed in exponential form: $$ 10000^{2i/d_{\text{model}}} = \exp\left(\log(10000) \cdot \frac{2i}{d_{\text{model}}}\right) $$ In the code we compute its reciprocal, $\exp\left(-\log(10000) \cdot \frac{2i}{d_{\text{model}}}\right)$, so it can be used as a multiplicative factor. Code:

      div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
      
    • Sinusoidal Values:

      pe[:, 0::2] = torch.sin(position * div_term)
      pe[:, 1::2] = torch.cos(position * div_term)
      
      • pe[:, 0::2]: Assigns the sine function values to even indices in the encoding.
      • pe[:, 1::2]: Assigns the cosine function values to odd indices in the encoding.
    • Usage: The positional encoding is added to the input embeddings, which combines the positional information with the token representations.

      def forward(self, x):
          return x + self.pe[:x.size(0), :]
      

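Putting the fragments above together, here is a minimal self-contained sketch of a sinusoidal positional encoding module (the max_len default and the (seq_len, batch, d_model) input layout are assumptions, matching the usage later in these notes):

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # sine on even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # cosine on odd dimensions
        pe = pe.unsqueeze(1)                          # (max_len, 1, d_model), broadcasts over the batch
        self.register_buffer('pe', pe)                # fixed (not learned), but saved/moved with the module

    def forward(self, x):
        # x: (seq_len, batch, d_model) -> add the encoding of each position to its embedding
        return x + self.pe[:x.size(0), :]

pos_enc = PositionalEncoding(d_model=16)
x = torch.zeros(10, 2, 16)   # (seq_len, batch, d_model)
print(pos_enc(x).shape)      # torch.Size([10, 2, 16])
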
--- Self-Attention

The paper “A Structured Self-Attentive Sentence Embedding” (Lin et al.) had already used self-attention for sentence encoding before the Transformer, but only as an auxiliary module on top of an RNN. The Transformer was the first model to systematically and extensively use self-attention as its core mechanism.

What is the difference between cross-attention and self-attention? -- The way Q, K, and V are chosen.

Self-attention: Q, K, and V all come from the same sequence (such as within the encoder or decoder). Cross-attention: Q comes from the target sequence (decoder), while K and V come from the source sequence (encoder).

Self-Attention Mechanism

Self-attention means each element attends to itself and others in the same sequence. For a sequence,

  1. Query (Q): Represents the current word (or token) for which we are calculating the attention. (Each word will be used as Q once)
  2. Key (K): Represents all the words (or tokens) in the sequence.
  3. Value (V): Also represents all the words (or tokens) in the sequence.

Toy Math Example: Processing the Sentence "I love AI"

  1. Input Embedding

    Suppose we have a simple input sentence: "I love AI". For simplicity, let's use the following embeddings for these words:

    • "I" -> $[0.1, 0.2]$
    • "love" -> $[0.3, 0.4]$
    • "AI" -> $[0.5, 0.6]$

    So, the input sequence $X$ is:

    X = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    
  2. Positional Encoding

    To provide positional information, we add simple positional encodings:

    • Position 0 -> $[0.0, 0.1]$
    • Position 1 -> $[0.1, 0.2]$
    • Position 2 -> $[0.2, 0.3]$

    Adding these to the embeddings, we get:

    Positional Encoded X = [[0.1 + 0.0, 0.2 + 0.1], [0.3 + 0.1, 0.4 + 0.2], [0.5 + 0.2, 0.6 + 0.3]]
                      = [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]]
    
  3. Linear Transformation to Obtain Q, K, V

    We perform linear transformations to get Query (Q), Key (K), and Value (V) matrices. For simplicity, assume the following weight matrices $W_Q$, $W_K$, and $W_V$:

    W_Q = W_K = W_V = [[1, 0], [0, 1]]
    

    Therefore:

    Q = K = V = [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]]
    
  4. Compute Attention Scores

    Compute the attention scores using the dot product of Q and K^T (for simplicity, we omit the 1/√d_k scaling here):

    Attention scores = Q * K^T
                   = [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]] * [[0.1, 0.4, 0.7], [0.3, 0.6, 0.9]]
                   = [[(0.1*0.1 + 0.3*0.3), (0.1*0.4 + 0.3*0.6), (0.1*0.7 + 0.3*0.9)],
                      [(0.4*0.1 + 0.6*0.3), (0.4*0.4 + 0.6*0.6), (0.4*0.7 + 0.6*0.9)],
                      [(0.7*0.1 + 0.9*0.3), (0.7*0.4 + 0.9*0.6), (0.7*0.7 + 0.9*0.9)]]
                   = [[0.10, 0.22, 0.34],
                      [0.22, 0.52, 0.82],
                      [0.34, 0.82, 1.30]]
    
  5. Apply Softmax to Get Attention Weights

    Normalize the scores using the softmax function:

    Attention weights = softmax(Attention scores)
    Approximating the softmax of each row:
                   ≈ [[0.29, 0.33, 0.37],
                      [0.24, 0.32, 0.44],
                      [0.19, 0.31, 0.50]]
    
  6. Compute the Weighted Sum of Values

    Use the attention weights to compute the weighted sum of the value vectors:

    Self-Attention output = Attention weights * V
                         = [[0.29, 0.33, 0.37], [0.24, 0.32, 0.44], [0.19, 0.31, 0.50]] * [[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]]
                         = [[(0.29*0.1 + 0.33*0.4 + 0.37*0.7), (0.29*0.3 + 0.33*0.6 + 0.37*0.9)],
                            [(0.24*0.1 + 0.32*0.4 + 0.44*0.7), (0.24*0.3 + 0.32*0.6 + 0.44*0.9)],
                            [(0.19*0.1 + 0.31*0.4 + 0.50*0.7), (0.19*0.3 + 0.31*0.6 + 0.50*0.9)]]
                         ≈ [[0.42, 0.62],
                            [0.46, 0.66],
                            [0.49, 0.69]]
    

In the above example, the calculations for attention scores, attention weights, and the weighted sums for each word ("I", "love", "AI") are independent and can be executed simultaneously.
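
The whole example can be reproduced with a few matrix operations in PyTorch (a minimal sketch; the exact softmax values differ slightly from the rounded approximations above):

import torch
import torch.nn.functional as F

X = torch.tensor([[0.1, 0.3], [0.4, 0.6], [0.7, 0.9]])  # positionally encoded embeddings
Q = K = V = X                        # W_Q = W_K = W_V = identity in the toy example
scores = Q @ K.T                     # (3, 3) attention scores; scaling omitted, as above
weights = F.softmax(scores, dim=-1)  # each row sums to 1
output = weights @ V                 # context vectors for "I", "love", "AI" computed at once
print(scores)  # ≈ [[0.10, 0.22, 0.34], [0.22, 0.52, 0.82], [0.34, 0.82, 1.30]]
print(output)  # ≈ [[0.42, 0.62], [0.46, 0.66], [0.49, 0.69]]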

--- Multi-head Attention

In the 2017 paper "Attention is All You Need", the authors first proposed Multi-Head Attention as an extension of self-attention and cross-attention that allows the model to learn different representations from multiple subspaces in parallel. It is one of the key innovations behind the success of the Transformer. Previous attention mechanisms (such as Bahdanau Attention, Luong Attention, or self-attention) were all single-head; the concept of "multi-head" did not exist yet.

Multi-Head Attention is like having a team of note-takers, where each person focuses on a different aspect of the same lecture: one listens for key facts, another pays attention to tone and emotion, another tracks cause-and-effect. Afterward, they combine their notes to form a more complete understanding. This allows the model to focus on different parts of the input in multiple ways, enhancing its ability to capture various aspects of the input data.

Like self-attention, Multi-Head Attention is not tied to the Transformer architecture and can also be used in RNN-based models.

Compute Multi-head Attention

Multi-head attention computes multiple sets of attention scores in parallel, each with different learned linear transformations of the same Q, K, V. Each head computes attention scores (usually scaled dot-product; other alignment functions are occasionally used) and then computes a context vector as its output.

$$\text{head}_i = \text{Attention}(QW_i^Q,\ KW_i^K,\ VW_i^V)$$

Where $W_i^Q$, $W_i^K$, and $W_i^V$ are learned projection matrices for the $i$-th head. Finally, we concatenate the outputs of all heads and (optionally) apply another learned linear transformation:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$

Where $W^O$ is a learned output projection matrix. In Multi-Head Attention, the final output $\text{MultiHead}(Q, K, V)$ must have the same dimension as the input embedding. This is because the MHA output is added to the input via a residual connection, and the dimensions must match: $$ \text{MHA}(x) + x $$ The final linear layer $W^O$ mainly serves this purpose. Of course, if $\text{Concat}(\text{head}_1, \ldots, \text{head}_h)$ already matches the input embedding size, $W^O$ may not be necessary.

Toy Math Example

  • Embedding Dimension (embed_dim): 4
  • Number of Heads (num_heads): 2
  • Output Dimension per Head: embed_dim / num_heads = 4 / 2 = 2

Assume we have 3 keys and values, and 1 query.

Head 1: $W_Q^{(1)} = \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \\ 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix} \qquad W_K^{(1)} = \begin{bmatrix} 0.9 & 1.0 \\ 1.1 & 1.2 \\ 1.3 & 1.4 \\ 1.5 & 1.6 \end{bmatrix} \qquad W_V^{(1)} = \begin{bmatrix} 1.7 & 1.8 \\ 1.9 & 2.0 \\ 2.1 & 2.2 \\ 2.3 & 2.4 \end{bmatrix}$

Head 2: $W_Q^{(2)} = \begin{bmatrix} 2.5 & 2.6 \\ 2.7 & 2.8 \\ 2.9 & 3.0 \\ 3.1 & 3.2 \end{bmatrix} \qquad W_K^{(2)} = \begin{bmatrix} 3.3 & 3.4 \\ 3.5 & 3.6 \\ 3.7 & 3.8 \\ 3.9 & 4.0 \end{bmatrix} \qquad W_V^{(2)} = \begin{bmatrix} 4.1 & 4.2 \\ 4.3 & 4.4 \\ 4.5 & 4.6 \\ 4.7 & 4.8 \end{bmatrix}$

Let’s assume our inputs for queries, keys, and values are:

  • Input Query: $\textbf{Q} = \begin{bmatrix} 0.5 & 0.6 & 0.7 & 0.8 \end{bmatrix}$

  • Input Key: $\textbf{K} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \end{bmatrix}$

  • Input Value: $\textbf{V} = \begin{bmatrix} 1.3 & 1.4 & 1.5 & 1.6 \\ 1.7 & 1.8 & 1.9 & 2.0 \\ 2.1 & 2.2 & 2.3 & 2.4 \end{bmatrix}$

  1. Head 1 Linear Transformations

    Query:

    $$\textbf{Q}_{\text{Head 1}} = \textbf{Q} \cdot W_Q^{(1)} = \begin{bmatrix} 0.5 & 0.6 & 0.7 & 0.8 \end{bmatrix} \cdot \begin{bmatrix} 0.1 & 0.2 \\ 0.3 & 0.4 \\ 0.5 & 0.6 \\ 0.7 & 0.8 \end{bmatrix} = \begin{bmatrix} 1.14 & 1.40 \end{bmatrix}$$

    Key:

    $$\textbf{K}_{\text{Head 1}} = \textbf{K} \cdot W_K^{(1)} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \end{bmatrix} \cdot \begin{bmatrix} 0.9 & 1.0 \\ 1.1 & 1.2 \\ 1.3 & 1.4 \\ 1.5 & 1.6 \end{bmatrix} = \begin{bmatrix} 1.30 & 1.40 \\ 3.22 & 3.48 \\ 5.14 & 5.56 \end{bmatrix}$$

    Value:

    $$\textbf{V}_{\text{Head 1}} = \textbf{V} \cdot W_V^{(1)} = \begin{bmatrix} 1.3 & 1.4 & 1.5 & 1.6 \\ 1.7 & 1.8 & 1.9 & 2.0 \\ 2.1 & 2.2 & 2.3 & 2.4 \end{bmatrix} \cdot \begin{bmatrix} 1.7 & 1.8 \\ 1.9 & 2.0 \\ 2.1 & 2.2 \\ 2.3 & 2.4 \end{bmatrix} = \begin{bmatrix} 11.70 & 12.28 \\ 14.90 & 15.64 \\ 18.10 & 19.00 \end{bmatrix}$$

  2. Head 2 Linear Transformations

    Query:

    $$\textbf{Q}_{\text{Head 2}} = \textbf{Q} \cdot W_Q^{(2)} = \begin{bmatrix} 0.5 & 0.6 & 0.7 & 0.8 \end{bmatrix} \cdot \begin{bmatrix} 2.5 & 2.6 \\ 2.7 & 2.8 \\ 2.9 & 3.0 \\ 3.1 & 3.2 \end{bmatrix} = \begin{bmatrix} 7.38 & 7.64 \end{bmatrix}$$

    Key:

    $$\textbf{K}_{\text{Head 2}} = \textbf{K} \cdot W_K^{(2)} = \begin{bmatrix} 0.1 & 0.2 & 0.3 & 0.4 \\ 0.5 & 0.6 & 0.7 & 0.8 \\ 0.9 & 1.0 & 1.1 & 1.2 \end{bmatrix} \cdot \begin{bmatrix} 3.3 & 3.4 \\ 3.5 & 3.6 \\ 3.7 & 3.8 \\ 3.9 & 4.0 \end{bmatrix} = \begin{bmatrix} 3.70 & 3.80 \\ 9.46 & 9.72 \\ 15.22 & 15.64 \end{bmatrix}$$

    Value:

    $$\textbf{V}_{\text{Head 2}} = \textbf{V} \cdot W_V^{(2)} = \begin{bmatrix} 1.3 & 1.4 & 1.5 & 1.6 \\ 1.7 & 1.8 & 1.9 & 2.0 \\ 2.1 & 2.2 & 2.3 & 2.4 \end{bmatrix} \cdot \begin{bmatrix} 4.1 & 4.2 \\ 4.3 & 4.4 \\ 4.5 & 4.6 \\ 4.7 & 4.8 \end{bmatrix} = \begin{bmatrix} 25.62 & 26.20 \\ 32.66 & 33.40 \\ 39.70 & 40.60 \end{bmatrix}$$

  3. Compute Attention Scores

    $$\text{Attention Scores}_{\text{Head 1}} = \textbf{Q}_{\text{Head 1}} \cdot \textbf{K}_{\text{Head 1}}^\top$$

    $$\text{Attention Scores}_{\text{Head 1}} = \begin{bmatrix} 1.14 \cdot 1.30 + 1.40 \cdot 1.40 \\ 1.14 \cdot 3.22 + 1.40 \cdot 3.48 \\ 1.14 \cdot 5.14 + 1.40 \cdot 5.56 \end{bmatrix} = \begin{bmatrix} 1.482 + 1.960 \\ 3.6708 + 4.872 \\ 5.8596 + 7.784 \end{bmatrix} = \begin{bmatrix} 3.442 \\ 8.5428 \\ 13.6436 \end{bmatrix}$$

    $$\text{Attention Scores}_{\text{Head 2}} = \textbf{Q}_{\text{Head 2}} \cdot \textbf{K}_{\text{Head 2}}^\top$$

    $$\text{Attention Scores}_{\text{Head 2}} = \begin{bmatrix} 7.38 \cdot 3.70 + 7.64 \cdot 3.80 \\ 7.38 \cdot 9.46 + 7.64 \cdot 9.72 \\ 7.38 \cdot 15.22 + 7.64 \cdot 15.64 \end{bmatrix} = \begin{bmatrix} 27.306 + 29.032 \\ 69.8148 + 74.2608 \\ 112.3236 + 119.4896 \end{bmatrix} = \begin{bmatrix} 56.338 \\ 144.0756 \\ 231.8132 \end{bmatrix}$$

    $$\text{Softmax}_{\text{Head 1}} = \left[ \frac{e^{3.442}}{\text{Sum}}, \frac{e^{8.5428}}{\text{Sum}}, \frac{e^{13.6436}}{\text{Sum}} \right] \approx \begin{bmatrix} 0.00004 \\ 0.0061 \\ 0.9939 \end{bmatrix}$$

    $$\text{Softmax}_{\text{Head 2}} \approx \begin{bmatrix} 6 \times 10^{-77} \\ 8 \times 10^{-39} \\ 1.0 \end{bmatrix}$$

  4. Context Vectors

    $$\text{Context}_{\text{Head 1}} = \text{Softmax}_{\text{Head 1}} \cdot \textbf{V}_{\text{Head 1}}$$

    $$\text{Context}_{\text{Head 1}} = \begin{bmatrix} 0.00004 & 0.0061 & 0.9939 \end{bmatrix} \cdot \begin{bmatrix} 11.70 & 12.28 \\ 14.90 & 15.64 \\ 18.10 & 19.00 \end{bmatrix} \approx \begin{bmatrix} 18.08 & 18.98 \end{bmatrix}$$

    $$\text{Context}_{\text{Head 2}} = \text{Softmax}_{\text{Head 2}} \cdot \textbf{V}_{\text{Head 2}}$$

    $$\text{Context}_{\text{Head 2}} = \begin{bmatrix} 6 \times 10^{-77} & 8 \times 10^{-39} & 1.0 \end{bmatrix} \cdot \begin{bmatrix} 25.62 & 26.20 \\ 32.66 & 33.40 \\ 39.70 & 40.60 \end{bmatrix} \approx \begin{bmatrix} 39.70 & 40.60 \end{bmatrix}$$
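
The numbers above can be checked with a short PyTorch sketch (Head 1 only; Head 2 is analogous). As in the hand calculation, the 1/√d_k scaling is omitted:

import torch
import torch.nn.functional as F

Q_in = torch.tensor([[0.5, 0.6, 0.7, 0.8]])              # (1, 4) input query
K_in = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                     [0.5, 0.6, 0.7, 0.8],
                     [0.9, 1.0, 1.1, 1.2]])              # (3, 4) input keys
V_in = torch.tensor([[1.3, 1.4, 1.5, 1.6],
                     [1.7, 1.8, 1.9, 2.0],
                     [2.1, 2.2, 2.3, 2.4]])              # (3, 4) input values

W_Q1 = torch.tensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
W_K1 = torch.tensor([[0.9, 1.0], [1.1, 1.2], [1.3, 1.4], [1.5, 1.6]])
W_V1 = torch.tensor([[1.7, 1.8], [1.9, 2.0], [2.1, 2.2], [2.3, 2.4]])

Q1, K1, V1 = Q_in @ W_Q1, K_in @ W_K1, V_in @ W_V1
weights1 = F.softmax(Q1 @ K1.T, dim=-1)  # 1/sqrt(d_k) scaling omitted, as in the text
print(weights1 @ V1)                     # ≈ [[18.08, 18.98]]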

Example Code of Multi-head Attention

A simplified version of how the MultiheadAttention class is implemented in PyTorch. As the formulas above show, the output of MultiheadAttention has the same embedding dimension as the input Q, K, V.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Linear layers for query, key, and value
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        
        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        # Project input tensors to query, key, and value
        query = self.q_proj(query)
        key = self.k_proj(key)
        value = self.v_proj(value)

        # Reshape for multi-head attention (inputs are assumed to be batch-first: (batch, seq, embed))
        batch_size, seq_len, embed_dim = query.size()
        kv_len = key.size(1)  # key/value length may differ from the query length (cross-attention)
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn_output, attn_output_weights = self.scaled_dot_product_attention(query, key, value, attn_mask, key_padding_mask)

        # Concatenate heads and project
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        attn_output = self.out_proj(attn_output)

        if need_weights:
            return attn_output, attn_output_weights
        else:
            return attn_output

    def scaled_dot_product_attention(self, query, key, value, attn_mask=None, key_padding_mask=None):
        # Calculate attention scores
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        if attn_mask is not None:
            attn_scores = attn_scores + attn_mask

        if key_padding_mask is not None:
            attn_scores = attn_scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf'))

        # Normalize attention scores to probabilities
        attn_weights = F.softmax(attn_scores, dim=-1)

        if self.dropout > 0.0:
            attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        # Weighted sum of values (Each item of the output corresponds to the weighted sum of V vectors w.r.t. the certain query)
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
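
A quick shape check of the sketch above (hypothetical sizes; with Q = K = V this is plain self-attention):

mha = MultiheadAttention(embed_dim=8, num_heads=2)
x = torch.randn(4, 5, 8)     # (batch, seq_len, embed_dim)
out, weights = mha(x, x, x)  # self-attention: Q = K = V = x
print(out.shape)             # torch.Size([4, 5, 8]) -- same shape as the input
print(weights.shape)         # torch.Size([4, 2, 5, 5]) -- (batch, num_heads, query_len, key_len)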

--- Transformers

Attention can be used as a module in a variety of models such as RNNs and MLPs. However, what makes the Transformer special is that it builds its entire architecture solely based on Attention, replacing RNNs/CNNs for handling sequential data.

In almost all cases, Transformers use Multi-Head Attention and Positional Encoding.

Easy Parallelism In Transformer

In RNNs, whether it's cross-attention or self-attention, computations cannot be parallelized because the hidden states (from which the Q, K and V in attention may be picked) are generated step by step. Attention can only be computed within the currently "generated" window; you can't apply attention to the entire input sequence at once.

Transformers, on the other hand, can compute in parallel because all their Q, K, and V vectors are obtained by applying linear transformations to the entire sequence at once, without relying on the output of the previous token. For the whole sequence, each token independently generates its own Q (query vector), and the attention for all queries is computed in parallel. Note that in Transformers, this parallelism specifically refers to parallel attention computation across positions.

Parallelism in the Transformer is key to making models bigger! The parallel attention computation is a significant advantage of Transformers over the original RNN attention models, making them highly efficient for large-scale sequence processing tasks. Transformers are the foundation of many state-of-the-art models in natural language processing, such as BERT and GPT.

Next Token Prediction Task: Slow Inference

For Next Token Prediction, Transformers can be parallelized during training because the target sequence is known, and targets at every position can be predicted simultaneously, simply by controlling the flow of information with masks. When training an autoregressive model (such as GPT or a translation model decoder), you already have the complete target output sequence, for example:

Input:      [A, B, C, D]
Target:     [B, C, D, <EOS>]

This means for each token, you know the “ground truth” next word. So you can:

  • Feed the entire target sequence into the model as decoder input, e.g., [A, B, C, D]
  • The model uses masked self-attention to predict the next word at every position simultaneously
    • The first position predicts B, the second predicts C, and so on;
    • All Q/K/V, attention, linear layers, etc., can be computed in parallel!
  • Add a mask to prevent information leakage
    • Although we feed in the full sequence, a causal mask ensures each position can only see what comes before it (it prevents peeking at the correct answer);
    • This preserves the property of autoregressive modeling.
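
A minimal sketch of such a causal (look-ahead) mask; torch.triu builds the float mask (0 for allowed positions, -inf for future positions) in the form nn.Transformer expects:

import torch

T = 4  # target sequence length
causal_mask = torch.triu(torch.full((T, T), float('-inf')), diagonal=1)
print(causal_mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])
# Added to the attention scores before softmax, row t can only attend to positions <= t.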

For Next Token Prediction, Transformers are highly parallel during training, but not during inference. The inference process is essentially like an RNN: tokens are generated step by step, in order, and each step depends on the output of the previous one. Specifically, at step t, Q comes from the most recently generated token, while K and V come from all tokens generated so far, so step t can only start after step t-1 has been predicted. This style of generation cannot process the generation positions in parallel.

Transformer inference is also slow because the attention computation has quadratic complexity, O(n²), in the sequence length.

KV Cache

The KV cache improves inference speed in next-token generation by caching the K and V projections instead of recomputing them at every step.

KV Cache means storing the K and V vectors generated by linear transformations of previous tokens, so that at each inference step, you don't have to recompute these for all previous tokens. In attention computation, for input token $x_t$, we compute:

$K_t = x_t \cdot W^K$

$V_t = x_t \cdot W^V$

If you don't use KV cache, at each generation step you'd have to repeat these linear transformations for all previous $x_1, x_2, ..., x_{t-1}$, which is very inefficient.

However, the KV cache does not reduce the computational complexity of attention: the total cost of generating a sequence remains O(n²).
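
A minimal single-head sketch of one decode step with a KV cache (the weight names and the Python-list cache are illustrative assumptions, not a real library API):

import torch
import torch.nn.functional as F

d_model = 8
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))
k_cache, v_cache = [], []          # grow by one entry per generated token

def decode_step(x_t):
    """x_t: (1, d_model) embedding of the newest token."""
    q_t = x_t @ W_q                # Q is needed only for the current token (hence no Q cache)
    k_cache.append(x_t @ W_k)      # K_t and V_t are computed exactly once ...
    v_cache.append(x_t @ W_v)      # ... and reused at every later step
    K = torch.cat(k_cache, dim=0)  # (t, d_model): all keys so far
    V = torch.cat(v_cache, dim=0)  # (t, d_model): all values so far
    attn = F.softmax(q_t @ K.T / d_model ** 0.5, dim=-1)
    return attn @ V                # context vector for the current step

for step in range(5):              # pretend we generate 5 tokens
    context = decode_step(torch.randn(1, d_model))
print(context.shape)               # torch.Size([1, 8])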

--- nn.Transformer Related Classes

The key component for implementation is the nn.Transformer module. It is a complete Encoder–Decoder Transformer, internally consisting of:

  • Encoder: uses TransformerEncoderLayer (self-attention)
  • Decoder: uses TransformerDecoderLayer (masked self-attention + cross-attention)

Therefore, PyTorch’s nn.Transformer fully implements the standard Transformer architecture (as in “Attention Is All You Need”), including cross-attention.

We can see that in TransformerEncoderLayer and TransformerDecoderLayer, the structure is not just multi-head attention; instead, each layer is multi-head attention followed by a feedforward network (multi-head attention -> FFN). In fact, in the original “Attention Is All You Need” paper, each encoder and decoder layer in the Transformer is structured this way. We may ask: attention itself is already very powerful, so why do we still need an FFN after each layer? Why is it not enough to just add another activation?

Attention is essentially a weighted average calculator: it can only reorder and aggregate information, but it is not a feature transformer. What multi-head attention essentially does is use attention weights to compute weighted sums of features, so the new features created are just linear combinations of the original features. From the perspective of linear algebra, linear combinations do not increase the dimensionality of the space, so attention does not create a new feature space. On the other hand, the linear layer in the FFN can truly change the feature space dimension, and with the activation function, the transformation done by the FFN becomes nonlinear and expressive enough to fit complex functions—completely transforming the feature space into a new one.

Details of nn.TransformerEncoderLayer and nn.TransformerEncoder

For nn.TransformerEncoderLayer and nn.TransformerEncoder, we only need to understand the structure of nn.TransformerEncoderLayer, because nn.TransformerEncoder is simply a stack of nn.TransformerEncoderLayers.

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            batch_first=True  
        )
        # FFN shape of ffn out == shape of ffn in
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # LayerNorm + Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        src: (N, S, E)
        src_mask: (S, S) or None
        src_key_padding_mask: (N, S) or None
        """
        attn_out, _ = self.self_attn( # shape of attn_out is the same as src: (N, S, E)
            src, src, src, # let Q=K=V=src(self-attention)
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask
        )
        # residual connection: src + attn_out -> norm
        src = self.norm1(src + self.dropout1(attn_out))
        # residual connection: src + ffn_out -> norm
        ffn_out = self.linear2(F.relu(self.linear1(src)))
        out = self.norm2(src + self.dropout2(ffn_out))
        return out

A flow chart to visualize the nn.TransformerEncoderLayer architecture:

            ┌───────────────┐
src ───────▶│ Multi-Head SA │─────┐
            └───────────────┘     │
                                  ▼
                             Add src & Norm (residual connection)
                                   │
                                   ▼
            ┌──────────────────────┐
            │         FFN          │
            └──────────────────────┘
                                   │
                                   ▼
                             Add src & Norm (residual connection)

For TransformerEncoderLayer and TransformerEncoder, the output and input dimensions are exactly the same.

Module                     Input Shape    Output Shape
TransformerEncoderLayer    (N, S, E)      (N, S, E)
TransformerEncoder         (N, S, E)      (N, S, E)

where

  • S = sequence length
  • N = batch size
  • E = embedding dimension for each token (also the dimension after linear transformation to Q, K, V).
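
A quick shape sanity check with the built-in modules (hypothetical sizes):

import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)  # simply 6 stacked copies of the layer

x = torch.randn(2, 10, 512)   # (N, S, E)
print(layer(x).shape)         # torch.Size([2, 10, 512]) -- a single layer
print(encoder(x).shape)       # torch.Size([2, 10, 512]) -- the full stack, same shape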

Details of nn.TransformerDecoderLayer and nn.TransformerDecoder

For nn.TransformerDecoderLayer and nn.TransformerDecoder, we only need to understand the structure of nn.TransformerDecoderLayer, because nn.TransformerDecoder is simply a stack of nn.TransformerDecoderLayers.

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # 1. Masked Self-Attention (for target sequence)
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            batch_first=True
        )
        # 2. Cross-Attention (attend to encoder output)
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            batch_first=True
        )
        # 3. Feed-Forward Network (same as in encoder)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        # LayerNorm + Dropout
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        tgt,                        # (N, T, E)
        memory,                     # (N, S, E)
        tgt_mask=None,              # (T, T)
        memory_mask=None,           # (T, S)
        tgt_key_padding_mask=None,  # (N, T)
        memory_key_padding_mask=None# (N, S)
    ):
        """
        tgt: target sequence (N, T, E)
        memory: encoder output (N, S, E)
        tgt_mask: look-ahead mask for self-attention
        memory_mask: optional mask for encoder-decoder attention
        """
        # --- 1. Masked Self-Attention ---
        tgt2, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask
        )
        tgt = self.norm1(tgt + self.dropout1(tgt2))

        # --- 2. Cross-Attention (Q from tgt, K/V from encoder memory) ---
        tgt2, _ = self.multihead_attn(
            tgt, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask
        )
        tgt = self.norm2(tgt + self.dropout2(tgt2))

        # --- 3. Feed-Forward Network ---
        ffn_out = self.linear2(F.relu(self.linear1(tgt)))
        out = self.norm3(tgt + self.dropout3(ffn_out))
        return out

A flow chart to visualize the nn.TransformerDecoderLayer architecture:

            ┌───────────────────┐
tgt ───────▶│ Masked Self-Attn  │────┐
            └───────────────────┘    │
                                     ▼
                                Add tgt & Norm (residual connection)
                                     │
                                     ▼
            ┌───────────────────┐
memory ────▶│  Cross-Attention  │────┐
            └───────────────────┘    │
                                     ▼
                                Add tgt & Norm (residual connection)
                                     │
                                     ▼
            ┌──────────────────────┐
            │         FFN          │
            └──────────────────────┘
                                     │
                                     ▼
                                Add tgt & Norm (residual connection)

For TransformerDecoderLayer and TransformerDecoder, the output and input dimensions are exactly the same.

Module                     Input Shape    Output Shape
TransformerDecoderLayer    (N, T, E)      (N, T, E)
TransformerDecoder         (N, T, E)      (N, T, E)

where

  • T = target sequence length
  • S = source sequence length (from the encoder output)
  • N = batch size
  • E = embedding dimension for each token (also the dimension of Q, K, V after linear transformation).

The difference between Encoder and Decoder Layers is that the Decoder Layer contains two attention blocks:

  1. Masked Self-Attention, where the decoder attends to previous tokens only.
  2. Cross-Attention, where the decoder attends to the encoder’s output.
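
A matching shape check for the decoder side (hypothetical sizes; memory stands in for an encoder output):

import torch
import torch.nn as nn

decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

tgt = torch.randn(2, 7, 512)      # (N, T, E) target sequence
memory = torch.randn(2, 10, 512)  # (N, S, E) encoder output
print(decoder_layer(tgt, memory).shape)  # torch.Size([2, 7, 512])
print(decoder(tgt, memory).shape)        # torch.Size([2, 7, 512])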

Details of nn.Transformer

For nn.Transformer, we only need to understand that it is composed of both an encoder and a decoder. Internally, it combines a stack of nn.TransformerEncoderLayers and a stack of nn.TransformerDecoderLayers into a complete Transformer architecture for sequence-to-sequence modeling.

class Transformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Encoder: stack of TransformerEncoderLayers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Decoder: stack of TransformerDecoderLayers
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.d_model = d_model
        self.nhead = nhead

    def forward(
        self,
        src,                        # (N, S, E)
        tgt,                        # (N, T, E)
        src_mask=None,              # (S, S)
        tgt_mask=None,              # (T, T)
        memory_mask=None,           # (T, S)
        src_key_padding_mask=None,  # (N, S)
        tgt_key_padding_mask=None,  # (N, T)
        memory_key_padding_mask=None# (N, S)
    ):
        """
        src: source sequence (input)
        tgt: target sequence (output)
        """
        # Encoder: encodes the source sequence
        memory = self.encoder(
            src,
            mask=src_mask,
            src_key_padding_mask=src_key_padding_mask
        )  # (N, S, E)

        # Decoder: generates the output sequence
        out = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )  # (N, T, E)

        return out

A flow chart to visualize the overall nn.Transformer architecture:

                ┌───────────────────────┐
src ───────────▶│ Transformer Encoder   │─────┐
                └───────────────────────┘     │
                                              ▼
                                         memory (encoder output)
                                              │
                                              ▼
                                  ┌───────────────────────┐
tgt ─────────────────────────────▶│ Transformer Decoder   │─────▶ output
                                  └───────────────────────┘

For Transformer, the encoder and decoder parts have matching embedding dimensions, and both preserve input–output dimensionality.

Module                Input Shape                       Output Shape
TransformerEncoder    (N, S, E)                         (N, S, E)
TransformerDecoder    (N, T, E)                         (N, T, E)
Transformer           src: (N, S, E), tgt: (N, T, E)    (N, T, E)

where

  • S = source sequence length
  • T = target sequence length
  • N = batch size
  • E = embedding dimension (shared between encoder and decoder)

Important parameters of nn.Transformer

  • d_model: the main dimensionality of the Transformer; the dimension of each token embedding, and also the size of the Q, K, V vectors after linear projection.
  • nhead: the number of attention heads, i.e. how many parallel heads are used in the Multi-Head Attention mechanism.
  • num_encoder_layers: how many TransformerEncoderLayers are stacked to form the encoder.
  • num_decoder_layers: how many TransformerDecoderLayers are stacked to form the decoder.
  • dim_feedforward: the hidden dimension of the feed-forward network (FFN); typically 4× larger than d_model, and controls the capacity of the FFN sub-layer.
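
A minimal end-to-end shape check of nn.Transformer (hypothetical sizes; src and tgt are assumed to be already embedded):

import torch
import torch.nn as nn

model = nn.Transformer(d_model=512, nhead=8, num_encoder_layers=6,
                       num_decoder_layers=6, dim_feedforward=2048, batch_first=True)

src = torch.randn(2, 10, 512)  # (N, S, E) source sequence
tgt = torch.randn(2, 7, 512)   # (N, T, E) target sequence
tgt_mask = torch.triu(torch.full((7, 7), float('-inf')), diagonal=1)  # causal mask for the decoder

out = model(src, tgt, tgt_mask=tgt_mask)
print(out.shape)               # torch.Size([2, 7, 512]) -- same shape as tgt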

--- Implementation: Using Transformer for Translation (Seq2Seq, Encoder and Decoder)

We use a seq2seq translation task here because the Transformer, as implemented in the original paper, is a standard seq2seq model: the input is a source-language sequence and the output is a target-language sequence, designed for machine translation.

A sequence starts with <bos>, and we predict the next token iteratively until <eos>. The decoder input and output have the same shape; the output can be regarded as the input shifted by one step, so the next word is at the last index of the output.

The decoder has two attention layers: the first is the self-attention of the target sequence (Q, K, V all come from tgt), and the second is the cross-attention, where Q comes from tgt and K, V come from the encoder output (memory).

As discussed earlier, parallelization is limited during inference. The model does not have access to the full target sequence; it generates one token at a time, so the prediction of the next token depends on the tokens generated so far.

In [1]:
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
In [2]:
# ----------
# Data Preparation
# ----------

# Define a simple dataset
data = [
    ("I am a student", "Je suis un étudiant"),
    ("He is a teacher", "Il est un enseignant"),
    ("She is a nurse", "Elle est une infirmière"),
    ("I love you", "Je t'aime"),
    ("How are you?", "Comment ça va?"),
]

# Build a vocabulary
def build_vocab(sentences):
    vocab = {"<unk>": 0, "<pad>": 1, "<bos>": 2, "<eos>": 3}
    for sentence in sentences:
        for word in sentence.split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab

src_sentences = [pair[0] for pair in data]
tgt_sentences = [pair[1] for pair in data]

src_vocab = build_vocab(src_sentences)
tgt_vocab = build_vocab(tgt_sentences)
print("SRC_VOCAB_SIZE:", len(src_vocab))
print("TGT_VOCAB_SIZE ", len(tgt_vocab))
print("src_vocab: ", src_vocab)
print("tgt_vocab: ", tgt_vocab)

# Convert sentences to tensors
def sentence_to_tensor(sentence, vocab):
    return torch.tensor([vocab[word] for word in sentence.split()], dtype=torch.long)

class TranslationDataset(Dataset):
    def __init__(self, data, src_vocab, tgt_vocab):
        self.data = data
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        src_tensor = sentence_to_tensor(src, self.src_vocab)
        tgt_tensor = sentence_to_tensor(tgt, self.tgt_vocab)
        return src_tensor, tgt_tensor

# DataLoader
BATCH_SIZE = 2
PAD_IDX = src_vocab["<pad>"]

def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(torch.cat([torch.tensor([src_vocab["<bos>"]]), src_sample, torch.tensor([src_vocab["<eos>"]])]))
        tgt_batch.append(torch.cat([torch.tensor([tgt_vocab["<bos>"]]), tgt_sample, torch.tensor([tgt_vocab["<eos>"]])]))
    src_batch = nn.utils.rnn.pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = nn.utils.rnn.pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

train_loader = DataLoader(TranslationDataset(data, src_vocab, tgt_vocab), batch_size=BATCH_SIZE, collate_fn=collate_fn)

print("Raw data before padding: ")
for src, tgt in data:
    src_tensor = sentence_to_tensor(src, src_vocab)
    tgt_tensor = sentence_to_tensor(tgt, tgt_vocab)
    print("src_tensor: ", src_tensor)
    print("tgt_tensor: ", tgt_tensor)
SRC_VOCAB_SIZE: 18
TGT_VOCAB_SIZE  18
src_vocab:  {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'I': 4, 'am': 5, 'a': 6, 'student': 7, 'He': 8, 'is': 9, 'teacher': 10, 'She': 11, 'nurse': 12, 'love': 13, 'you': 14, 'How': 15, 'are': 16, 'you?': 17}
tgt_vocab:  {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'Je': 4, 'suis': 5, 'un': 6, 'étudiant': 7, 'Il': 8, 'est': 9, 'enseignant': 10, 'Elle': 11, 'une': 12, 'infirmière': 13, "t'aime": 14, 'Comment': 15, 'ça': 16, 'va?': 17}
Raw data before padding: 
src_tensor:  tensor([4, 5, 6, 7])
tgt_tensor:  tensor([4, 5, 6, 7])
src_tensor:  tensor([ 8,  9,  6, 10])
tgt_tensor:  tensor([ 8,  9,  6, 10])
src_tensor:  tensor([11,  9,  6, 12])
tgt_tensor:  tensor([11,  9, 12, 13])
src_tensor:  tensor([ 4, 13, 14])
tgt_tensor:  tensor([ 4, 14])
src_tensor:  tensor([15, 16, 17])
tgt_tensor:  tensor([15, 16, 17])
In [ ]:
# ----------
# Model Architecture
# ----------

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = self._generate_positional_encoding(d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
    
    def _generate_positional_encoding(self, d_model, max_len=5000):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(1)  # Add batch dimension
        return pe

    def forward(self, src, tgt):
        src = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        
        # Apply positional encoding
        src = src + self.positional_encoding[:src.size(0), :]
        tgt = tgt + self.positional_encoding[:tgt.size(0), :]
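        # NOTE: for fully autoregressive training, a causal tgt_mask (e.g. from
        # self.transformer.generate_square_subsequent_mask(tgt.size(0))) should also be
        # passed to self.transformer below; it is omitted in this toy example.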
        
        output = self.transformer(src, tgt)
        final_layer = self.fc_out(output)
        return final_layer
In [ ]:
# ----------
# Training
# ----------

# Training configuration
SRC_VOCAB_SIZE = len(src_vocab)
TGT_VOCAB_SIZE = len(tgt_vocab)
model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = optim.Adam(model.parameters(), lr=0.0005)

# Training loop
NUM_EPOCHS = 10

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    for src_batch, tgt_batch in train_loader:
        print("src, tgt: ", src_batch.shape, tgt_batch.shape) # src and tgt after padding
        optimizer.zero_grad()

        # Remove last token from tgt_batch for input
        tgt_input = tgt_batch[:-1, :]
        
        # Forward pass
        output = model(src_batch, tgt_input)

        # Remove the first token from tgt_batch for target
        tgt_out = tgt_batch[1:, :].reshape(-1)

        # Calculate loss
        loss = criterion(output.reshape(-1, output.shape[-1]), tgt_out)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {total_loss / len(train_loader):.4f}')
/Users/yy544/.local/share/virtualenvs/Cookbook-0P5uvQVm/lib/python3.10/site-packages/torch/nn/modules/transformer.py:307: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
  warnings.warn(f"enable_nested_tensor is True, but self.use_nested_tensor is False because {why_not_sparsity_fast_path}")
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 1/10, Loss: 3.0812
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 2/10, Loss: 2.2407
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 3/10, Loss: 1.8884
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 4/10, Loss: 1.3930
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 5/10, Loss: 1.2496
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 6/10, Loss: 0.8894
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 7/10, Loss: 0.5083
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 8/10, Loss: 0.3647
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 9/10, Loss: 0.2143
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([6, 2]) torch.Size([6, 2])
src, tgt:  torch.Size([5, 1]) torch.Size([5, 1])
Epoch 10/10, Loss: 0.1677
In [ ]:
# ----------
# Inference
# ----------

def greedy_decode(model, src, max_len, start_symbol):
    src = src.unsqueeze(1)
    src_mask = model.transformer.generate_square_subsequent_mask(src.size(0)).type(torch.bool)
    memory = model.transformer.encoder(model.src_embedding(src) + model.positional_encoding[:src.size(0), :])
    
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long) # Start with <bos>, predict the next token iteratively until <eos>
    for i in range(max_len-1):
        tgt_mask = model.transformer.generate_square_subsequent_mask(ys.size(0)).type(torch.bool)
        out = model.transformer.decoder(model.tgt_embedding(ys) + model.positional_encoding[:ys.size(0), :], memory, tgt_mask=tgt_mask)
        out = model.fc_out(out)
        _, next_word = torch.max(out[-1, :], dim=1) # Next word is at the last idx of the output. The input and output have same shape. We can regard the output as shift one step of input.
        next_word = next_word.item()

        # The entire sequence of generated tokens up to the newly generated next_word as Q for the transformer decoder
        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == tgt_vocab['<eos>']:
            break
    return ys

# Translate a sentence
model.eval()
src_sentence = "I love you"
src_tensor = torch.tensor([src_vocab["<bos>"]] + [src_vocab[word] for word in src_sentence.split()] + [src_vocab["<eos>"]])

# Generate translation
translated_sentence = greedy_decode(model, src_tensor, max_len=10, start_symbol=tgt_vocab['<bos>'])

# Convert tokens back to words
translated_words = [list(tgt_vocab.keys())[list(tgt_vocab.values()).index(idx)] for idx in translated_sentence]
print(" ".join(translated_words))
<bos> Je t'aime <eos>