RNN with Attention

The attention mechanism allows the model to dynamically focus on different parts of the input sequence, rather than treating all input tokens as equally important.

Attention first appeared in 2014, when Bahdanau et al. added it to an RNN-based seq2seq model. It was introduced to address the fixed-length context vector bottleneck, and it was the first application of attention within the RNN framework. It was later fully developed and popularized by the Transformer model.

The attention mechanism itself should be understood independently from RNN-based attention or Transformer-based attention. Fundamentally, attention just requires inputs for Q (query), K (key), and V (value), plus an alignment function; that's all you need to compute attention. Whether it's attention in RNNs or Transformers simply depends on where the Q, K, and V come from (see the Attention Mechanism Note for a brief intro to attention).

In general, we have two types of RNN models in terms of the output type.

  • Vanilla RNN: For tasks such as text classification, sentiment analysis, or time-series prediction. The input is a sequence, but the output is just a single label or a single value.

    • e.g., Input: "The movie is great" → Output: "Positive"
  • Seq2Seq RNN: The input is a sequence and the output is also a sequence, which may have the same or a different length as the input. The output sequence is generated step by step, where each output token depends on the previously generated tokens.

    • e.g., Input: English "The cat sits" → Output: Chinese "猫坐下了"

Attention was first introduced in the Seq2Seq RNN setup, but as we mentioned earlier, it should be understood as an independent mechanism. As long as the input is a sequence, attention can be applied.

  • Vanilla RNN (seq→label): By default, it only uses the final hidden state, but attention can be added to let the model automatically focus on the most important time steps (see the sketch after this list).
  • Seq2Seq RNN (seq→seq): Attention is essentially a standard component, since the decoder needs to decide where to attend at each output step.
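For the seq→label case, a minimal sketch of attention pooling over an RNN's hidden states is shown below (the layer names and sizes here are my own, not from a specific paper; the learned scoring layer plays the role of a fixed query):

import torch
import torch.nn as nn
import torch.nn.functional as F

class RNNWithAttentionPooling(nn.Module):
    """Toy seq -> label model: classify from an attention-weighted sum of RNN states."""
    def __init__(self, embed_dim, hidden_dim, num_classes):
        super().__init__()
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.score = nn.Linear(hidden_dim, 1)        # alignment score for each time step
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x_emb):                        # x_emb: (batch, seq_len, embed_dim)
        out, _ = self.rnn(x_emb)                     # out: (batch, seq_len, hidden_dim)
        weights = F.softmax(self.score(out), dim=1)  # attention weights: (batch, seq_len, 1)
        pooled = (weights * out).sum(dim=1)          # weighted sum instead of the last hidden state
        return self.classifier(pooled)               # logits: (batch, num_classes)

Compared with using only the final hidden state, the model learns per-step weights and classifies from the weighted sum of all hidden states.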

This note will focus on the Seq2Seq RNN model with attention. Two takeaways:

  1. How to pick Q, K, and V in the Seq2Seq RNN setup, and how to integrate the context vector into the Seq2Seq RNN model?
  2. How to implement and train a Seq2Seq RNN attention model?

--- Key Concepts

  1. Attention Mechanism:

    • Alignment Function: A function that calculates an alignment score between the query and each key.

    • Alignment Score (also Attention Score): The output of the alignment function. It has two interpretations: how well each K matches the Q, and how much attention each V should receive when generating the context vector.

    • Attention Weight: The normalized alignment score (typically normalized with a softmax function). It determines how much attention each V should receive when generating the context vector.

    • Context Vector: The weighted sum of the V using the attention weights. It represents the relevant information from the V needed to produce the current output. (A small numeric sketch of this score → weight → context pipeline follows the list below.)

  2. Scope of Attention:

    • Global Attention: Considers all input tokens when computing the attention weights.

    • Local Attention: Focuses on a specific subset of input tokens, typically using a sliding window or another heuristic to limit the scope.

  3. Source of Attention:

    • Self-Attention:
      Q, K, V all come from the same sequence. Each token attends to all others in the same input. Used in Transformer encoders.

    • Cross-Attention:
      Q comes from one sequence (e.g., decoder), K and V come from another (e.g., encoder). Enables interactions between sequences.

    • Multi-Modal Attention:
      A special case of cross-attention where Q and K/V come from different modalities (e.g., text vs image).
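To make these terms concrete, here is a minimal numeric sketch of the score → weight → context pipeline (random toy tensors, dot-product alignment for simplicity; the rest of this note uses additive alignment). The comments mark the only place where self- and cross-attention differ: the source of Q, K, and V.

import torch
import torch.nn.functional as F

def attend(Q, K, V):
    scores = Q @ K.T                     # alignment scores: (num_queries, num_keys)
    weights = F.softmax(scores, dim=-1)  # attention weights: each row sums to 1
    context = weights @ V                # context vectors: (num_queries, value_dim)
    return context, weights

enc_states = torch.randn(5, 4)           # e.g., encoder hidden states for 5 source tokens
dec_states = torch.randn(3, 4)           # e.g., decoder hidden states for 3 target steps

self_ctx, _ = attend(enc_states, enc_states, enc_states)   # self-attention: Q, K, V from the same sequence
cross_ctx, _ = attend(dec_states, enc_states, enc_states)  # cross-attention: Q from decoder, K/V from encoder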

--- Seq2Seq Attention Model with Additive Alignment

  • Additive (Bahdanau) Attention

    The alignment score is computed by projecting the query vector $Q$ and the key vector $K$ with separate linear layers, adding the results, and passing the sum through a tanh activation function. $$\text{Score}(Q, K) = v^\top \tanh( W_Q Q + W_K K )$$ Where $W_Q$ and $W_K$ are weight matrices, and $v$ is a vector that projects the combined vector down to a scalar score (notice that $v$ is a parameter that will be updated in backpropagation; it is completely different from $V$). The softmax function is then applied to obtain the attention weights. $$\text{Attention}(Q, K, V) = \text{softmax}(\text{Score}) V$$
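A minimal functional sketch of these two equations is given below (tensor shapes are arbitrary; the nn.Module version used for training appears in the implementation section):

import torch

def additive_score(q, K, W_Q, W_K, v):
    # q: (hidden,); K: (seq_len, hidden); W_Q, W_K: (attn, hidden); v: (attn,)
    return torch.tanh(K @ W_K.T + W_Q @ q) @ v                         # one score per key: (seq_len,)

def additive_attention(q, K, V, W_Q, W_K, v):
    weights = torch.softmax(additive_score(q, K, W_Q, W_K, v), dim=0)  # attention weights
    return weights @ V, weights                                        # context vector, weights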

In this toy example, we use a simple sequence-to-sequence (seq2seq) model with an RNN-based encoder-decoder architecture enhanced by additive attention (Bahdanau attention). The input consists of a short sequence of 3 words: ["the", "cat", "sits"].

  • The encoder is a unidirectional RNN that processes the input sequence word by word and outputs a hidden state at each step. In our setup, each word embedding is a 3-dimensional vector, and the encoder maps it into a 2-dimensional hidden state.
  • The decoder is another RNN that generates the output sequence one step at a time. At each time step, it produces a 2-dimensional hidden state, which acts as the query for the attention mechanism. The decoder then combines this query with the attention-derived context vector (also 2D) to make a prediction.
  • The attention mechanism sits between the encoder and decoder (cross-attention). It uses additive attention to compute a relevance score between the decoder’s current hidden state (query) and each encoder hidden state (key/value), producing a context vector as a weighted sum of encoder states.

Step 1: Encoder Hidden States

We assume the encoder has already processed the input words and produced hidden states (e.g., via an RNN). For simplicity, suppose each hidden state is a 2-dimensional vector. These are our encoder outputs.

  • $h_1$ for "the": $[1, 0]$
  • $h_2$ for "cat": $[0, 1]$
  • $h_3$ for "sits": $[1, 1]$

Step 2: Decoder Hidden State (Query)

Now, suppose the decoder is at time step $t$, and its current hidden state is as follows. This will be our query (Q).

  • $q_t = [1, 1]$

Step 3: Compute Alignment Scores (Additive Attention)

In Additive Attention, the alignment score is computed as: $$ e_i = v^\top \tanh(W_h h_i + W_q q_t) $$ We'll choose simple toy matrices to make the calculation easy. In reality, these are parameters learned during optimization.

  • $W_h = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix}$ (identity)
  • $W_q = \begin{bmatrix} 1 & 1 \\ 1 & 1 \end{bmatrix}$
  • $v = \begin{bmatrix} 1 \\ 1 \end{bmatrix}$

Alignment score for "the" (i = 1):

  • $h_1 = [1, 0]$
  • $W_h h_1 = [1, 0]$
  • $W_q q_t = [2, 2]$
  • Sum: $[1, 0] + [2, 2] = [3, 2]$
  • $\tanh([3, 2]) \approx [0.995, 0.964]$
  • Score: $v^\top \cdot [0.995, 0.964] = 0.995 + 0.964 = 1.959$

Alignment score for "cat" (i = 2):

  • $h_2 = [0, 1]$
  • $W_h h_2 = [0, 1]$
  • Sum: $[0, 1] + [2, 2] = [2, 3]$
  • $\tanh([2, 3]) \approx [0.964, 0.995]$
  • Score: $v^\top \cdot [0.964, 0.995] = 0.964 + 0.995 = 1.959$

Alignment score for "sits" (i = 3):

  • $h_3 = [1, 1]$
  • $W_h h_3 = [1, 1]$
  • Sum: $[1, 1] + [2, 2] = [3, 3]$
  • $\tanh([3, 3]) \approx [0.995, 0.995]$
  • Score: $v^\top \cdot [0.995, 0.995] = 0.995 + 0.995 = 1.99$

Step 4: Apply Softmax to Alignment Scores

We now apply softmax to the alignment scores $e = [1.959, 1.959, 1.99]$: $$ \text{Softmax}(1.959, 1.959, 1.99) = \left( \frac{e^{1.959}}{Z}, \frac{e^{1.959}}{Z}, \frac{e^{1.99}}{Z} \right) $$ where $e^{1.959} \approx 7.09$, $e^{1.99} \approx 7.32$, $Z = 7.09 + 7.09 + 7.32 = 21.5$. So

  • $\alpha_1 = \alpha_2 \approx \frac{7.09}{21.5} \approx 0.33$
  • $\alpha_3 \approx \frac{7.32}{21.5} \approx 0.34$

Step 5: Compute Context Vector

Now, compute the context vector as the weighted sum of the encoder hidden states (V): $$ c_t = 0.33 \cdot [1, 0] + 0.33 \cdot [0, 1] + 0.34 \cdot [1, 1] $$ $$ = [0.33, 0] + [0, 0.33] + [0.34, 0.34] = [0.67, 0.67] $$
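The hand computations in Steps 3 to 5 can be verified in a few lines of PyTorch, using the same toy $W_h$, $W_q$, $v$, encoder states, and query:

import torch

H = torch.tensor([[1., 0.], [0., 1.], [1., 1.]])  # encoder states h_1..h_3 (keys and values)
q = torch.tensor([1., 1.])                        # decoder state q_t (query)
W_h, W_q, v = torch.eye(2), torch.ones(2, 2), torch.ones(2)

scores = torch.tanh(H @ W_h.T + q @ W_q.T) @ v    # e ≈ [1.959, 1.959, 1.990]
weights = torch.softmax(scores, dim=0)            # alpha ≈ [0.33, 0.33, 0.34]
context = weights @ H                             # c_t ≈ [0.67, 0.67]
print(scores, weights, context)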

Step 6: Where Does the Context Vector Go?

Once we compute the context vector from attention, we don't just stop there. The context vector is combined with the decoder's current hidden state to produce the output at this step (in some variants it is also fed into the decoder's next step). Specifically: $$ \text{Decoder Output} = f(s_t, c_t) $$ where:

  • $s_t$ is the current decoder hidden state
  • $c_t$ is the context vector from attention
  • $f(\cdot)$ might be a feedforward layer followed by a softmax to predict the next word
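A minimal sketch of this step, using the toy dimensions from this example (the linear layer is randomly initialized here, so the numbers are only illustrative; the Decoder class in the implementation section does exactly this with torch.cat followed by nn.Linear):

import torch
import torch.nn as nn

hidden_dim, vocab_size = 2, 9
s_t = torch.tensor([[1.0, 1.0]])           # current decoder hidden state: (1, hidden_dim)
c_t = torch.tensor([[0.67, 0.67]])         # context vector from attention: (1, hidden_dim)

f = nn.Linear(hidden_dim * 2, vocab_size)  # one common choice of f(s_t, c_t)
logits = f(torch.cat([s_t, c_t], dim=-1))  # (1, vocab_size)
probs = torch.softmax(logits, dim=-1)      # distribution over the next output word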

Recap

In the original attention mechanism (such as Bahdanau Attention, 2015), the query comes from the decoder, while the keys and values come from the encoder. This design is backed by a clear and intuitive motivation. Let’s explain this design in the context of machine translation, which was the earliest application of attention.

When translating a sentence from English to Chinese, we want the decoder to be able to selectively access the source-language context — to "look at" the parts of the input sentence that are most relevant at each step of the generation process.

Suppose we want to translate the sentence: English: "The cat sits on the mat." Chinese: "猫坐在垫子上。"

When the decoder is generating the word “垫子” (“mat”), the query is the decoder’s current hidden state. Its goal is essentially to ask: “Which part of the English sentence is related to the word I’m about to generate?”

The keys (i.e., all the encoder hidden states) serve as indexed reference points. The attention mechanism computes alignment scores between the query and each key to decide where to focus. The values (which are usually the same as the encoder outputs) contain the actual content. The attention weights are then used to compute a weighted sum over the values — producing the context vector that reflects what the decoder should focus on from the source sentence at this step. In this case, we use the encoder hidden states as both K and V.

So the process works like this: The decoder (query) asks a question, the encoder (keys) offers choices, and the encoder (values) supplies the content. This mechanism closely mimics how a human translator pays selective attention to the source sentence when producing each target word.

And if the context vector leads the model to make an incorrect prediction, the gradient signal during backpropagation essentially tells the model: “Hey, you attended to the wrong parts of the input — next time, adjust the scores so you look at the right tokens.”

--- Seq2Seq Attention Model with Additive Alignment: Implementation

In the previous section, we worked through a step-by-step mathematical example of how attention is computed in a Seq2Seq RNN model with additive (Bahdanau) attention. In short, this code is a toy version of the math example we just derived, showing how the same computations can be implemented in PyTorch.

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Use a font that can render Chinese characters (commonly available on macOS)
rcParams['font.sans-serif'] = ['Arial Unicode MS']
rcParams['axes.unicode_minus'] = False  # avoid issues rendering the minus sign
In [15]:
# -------------------------------
# Model
# -------------------------------

class Encoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True) 

    def forward(self, x_emb): # x_emb: (batch_size, seq_len, embed_dim)
        return self.rnn(x_emb) # out, h_T. the shape of out is (batch_size, seq_len, hidden_dim)


class AdditiveAttention(nn.Module):
    def __init__(self, hidden_dim, attn_dim):
        super().__init__()
        self.W_h = nn.Linear(hidden_dim, attn_dim)
        self.W_q = nn.Linear(hidden_dim, attn_dim)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, decoder_state, encoder_outputs):
        # decoder_state: (batch_size, hidden_dim); encoder_outputs: (batch_size, seq_len, hidden_dim)
        seq_len = encoder_outputs.size(1) 
        decoder_exp = decoder_state.unsqueeze(1).expand(-1, seq_len, -1) # (batch_size, seq_len, hidden_dim) repeat query to match the shape of keys
        energy = torch.tanh(self.W_h(encoder_outputs) + self.W_q(decoder_exp))  # (batch_size, seq_len, attn_dim)
        scores = self.v(energy).squeeze(-1) # attention score (batch_size, seq_len) 
        weights = F.softmax(scores, dim=-1) # attention weights (batch_size, seq_len)
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1) # context vector (batch_size, hidden_dim)
        return context, weights


class Decoder(nn.Module):
    def __init__(self, embed_dim, hidden_dim, attn_dim, vocab_size):
        super().__init__()
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)
        self.attn = AdditiveAttention(hidden_dim, attn_dim)
        self.out = nn.Linear(hidden_dim * 2, vocab_size)

    def forward(self, input_emb, hidden, encoder_outputs): # input_emb: (batch_size, 1, embed_dim); hidden: (1, batch_size, hidden_dim); encoder_outputs: (batch_size, seq_len, hidden_dim)
        # For seq2seq models, decoder takes a single input token at a time (usually the output of the previous time step)
        # encoder_outputs contains the hidden states of the encoder at ALL time steps. To compute the context vector.
        rnn_out, hidden = self.rnn(input_emb, hidden)
        dec_state = rnn_out.squeeze(1)
        context, attn_weights = self.attn(dec_state, encoder_outputs)
        combined = torch.cat([dec_state, context], dim=-1)  # (batch_size, dec_hidden_dim + enc_hidden_dim)
        logits = self.out(combined) # (batch_size, vocab_size)
        return logits, hidden, attn_weights
In [16]:
# -------------------------------
# Data Preparation
# -------------------------------

# Define vocabulary and create mapping dictionaries
# vocab: List of all tokens including special tokens (<pad>, <sos>, <eos>) and words in both languages
# word2idx: Maps each word to a unique integer index for model input
#   e.g. {"<pad>": 0, "<sos>": 1, "<eos>": 2, "the": 3, "cat": 4, ...}
# idx2word: Reverse mapping from indices back to words for model output
#   e.g. {0: "<pad>", 1: "<sos>", 2: "<eos>", 3: "the", 4: "cat", ...}
vocab = ["<pad>", "<sos>", "<eos>", "the", "cat", "sits", "猫", "坐", "在"]
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}

# Embeddings (3D toy embeddings)
embedding_matrix = torch.tensor([
    [0.0, 0.0, 0.0],  # <pad>
    [0.5, 0.5, 0.5],  # <sos>
    [0.0, 0.0, 0.0],  # <eos>
    [1.0, 0.0, 1.0],  # the
    [0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0],  # sits
    [1.0, 0.5, 0.0],  # 猫
    [0.0, 1.0, 1.0],  # 坐
    [1.0, 0.0, 0.5],  # 在
])
embedding_layer = nn.Embedding.from_pretrained(embedding_matrix)


input_tokens = ["the", "cat", "sits"]
target_tokens = ["<sos>", "猫", "坐", "在"]  # remove <eos> for simplicity

# Convert tokens to indices
# input_ids:  tensor([3, 4, 5])     # "the", "cat", "sits"
# target_ids: tensor([1, 6, 7, 8])  # "<sos>", "猫", "坐", "在"
input_ids = torch.tensor([word2idx[w] for w in input_tokens])
target_ids = torch.tensor([word2idx[w] for w in target_tokens])
In [17]:
# -------------------------------
# Training Configuration
# -------------------------------

# Hyperparameters
embed_dim = 3
hidden_dim = 2
attn_dim = 2
vocab_size = len(vocab)
lr = 0.01
epochs = 300

encoder = Encoder(embed_dim, hidden_dim)
decoder = Decoder(embed_dim, hidden_dim, attn_dim, vocab_size)
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=lr)
criterion = nn.CrossEntropyLoss()


# -------------------------------
# Training Loop
# -------------------------------

# Training
input_emb = embedding_layer(input_ids).unsqueeze(0) # (1, seq_len, embed_dim)
for epoch in range(epochs):
    optimizer.zero_grad()
    # enc_out is used to compute the context vector; enc_hidden is used to initialize the decoder hidden state to ensure at the start of decoding, the decoder already possesses a high-level semantic representation of the input sentence.
    enc_out, enc_hidden = encoder(input_emb) # enc_out: (1, seq_len, hidden_dim); enc_hidden: (1, 1, hidden_dim)
    dec_hidden = enc_hidden
    loss = 0.0

    for t in range(len(target_ids) - 1):
        # Teacher Forcing: use the ground truth token instead of the predicted token to avoid exposure bias
        # while in inference, we have to use the predicted token to generate the next token
        dec_input_id = target_ids[t].unsqueeze(0)
        dec_input_emb = embedding_layer(dec_input_id).unsqueeze(1) # (1, 1, embed_dim)
        # predict the next token
        logits, dec_hidden, _ = decoder(dec_input_emb, dec_hidden, enc_out)
        loss += criterion(logits, target_ids[t + 1].unsqueeze(0))

    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
Epoch 0, Loss: 6.6746
Epoch 50, Loss: 2.5382
Epoch 100, Loss: 1.7085
Epoch 150, Loss: 0.5144
Epoch 200, Loss: 0.2398
Epoch 250, Loss: 0.1537
In [20]:
# -------------------------------
# Inference (greedy decoding)
# -------------------------------

def greedy_decode(encoder, decoder, embedding_layer, input_ids, word2idx, idx2word, max_len=10, verbose=False):
    """
    Encodes the input sequence once with the encoder, then lets the decoder generate tokens 
    step by step by always selecting the most probable word (greedy), until <eos> is produced 
    or the maximum length is reached.

    Args:
        encoder (nn.Module): The encoder model (usually an RNN).
        decoder (nn.Module): The decoder model (with attention mechanism).
        embedding_layer (nn.Embedding): Embedding lookup layer for token IDs.
        input_ids (Tensor): Token IDs of the input sentence 
            (e.g., [3, 4, 5] corresponds to "the cat sits").
        word2idx (dict): Mapping from tokens to IDs.
        idx2word (dict): Mapping from IDs to tokens.
        max_len (int, optional): Maximum length of the generated sequence. Defaults to 10.
        verbose (bool, optional): If True, print each decoding step along with attention weights.
            Defaults to False.

    Returns:
        decoded_tokens (list of str): The generated output tokens.
        attention_maps (list of list[float]): Attention weights at each decoding step.
    """
    
    input_emb = embedding_layer(input_ids).unsqueeze(0)
    encoder_outputs, encoder_hidden = encoder(input_emb)
    decoder_hidden = encoder_hidden
    decoder_input = torch.tensor([word2idx["<sos>"]])
    decoded_tokens = []
    attention_maps = []

    for t in range(max_len):
        decoder_input_emb = embedding_layer(decoder_input).unsqueeze(1)
        logits, decoder_hidden, attn_weights = decoder(decoder_input_emb, decoder_hidden, encoder_outputs)
        pred_id = logits.argmax(dim=-1).item()
        pred_token = idx2word[pred_id]
        if verbose:
            print(f"[Step {t}] {pred_token}, attn: {attn_weights.squeeze().tolist()}")
        attention_maps.append(attn_weights.squeeze().tolist())
        if pred_token == "<eos>":
            break
        decoded_tokens.append(pred_token)
        decoder_input = torch.tensor([pred_id])

    return decoded_tokens, attention_maps

decoded, attn = greedy_decode(
    encoder, decoder, embedding_layer, input_ids, # input_ids:  tensor([3, 4, 5])     # "the", "cat", "sits"
    word2idx, idx2word, verbose=True
)
print("Generated:", decoded)
[Step 0] 猫, attn: [0.3320370614528656, 0.33273470401763916, 0.33522823452949524]
[Step 1] 坐, attn: [0.33232754468917847, 0.33286750316619873, 0.3348049521446228]
[Step 2] 在, attn: [0.33323949575424194, 0.33328646421432495, 0.3334740698337555]
[Step 3] 在, attn: [0.33301979303359985, 0.3331827223300934, 0.3337974548339844]
[Step 4] 坐, attn: [0.33262696862220764, 0.33300304412841797, 0.3343698978424072]
[Step 5] 在, attn: [0.33322328329086304, 0.3332785665988922, 0.33349815011024475]
[Step 6] 在, attn: [0.33288416266441345, 0.33312007784843445, 0.3339958190917969]
[Step 7] 坐, attn: [0.33265358209609985, 0.3330153524875641, 0.33433106541633606]
[Step 8] 在, attn: [0.3332301676273346, 0.333281934261322, 0.3334878385066986]
[Step 9] 在, attn: [0.3329254686832428, 0.3331390917301178, 0.333935409784317]
Generated: ['猫', '坐', '在', '在', '坐', '在', '在', '坐', '在', '在']
In [23]:
# -------------------------------
# Plot Attention Heatmap for Different Queries
# -------------------------------

def plot_attention(attn, input_tokens, output_tokens):
    fig, ax = plt.subplots()
    im = ax.imshow(attn, cmap="Blues")
    ax.set_xticks(range(len(input_tokens)))
    ax.set_xticklabels(input_tokens)
    ax.set_yticks(range(len(output_tokens)))
    ax.set_yticklabels(output_tokens)
    ax.set_xlabel("Input")
    ax.set_ylabel("Output")
    ax.set_title("Attention Heatmap")
    plt.colorbar(im)
    plt.show()

plot_attention(attn, input_tokens, decoded)
[Figure: attention heatmap of output tokens (y-axis) vs. input tokens (x-axis)]

--- Other Options of Query, Key and Value

How to Choose the Query

While the decoder's current hidden state is the most common choice for the query, there are several other options and variations. Here are a few alternatives (a small sketch of the first two follows the list):

  1. Decoder Input Embeddings

    • Instead of using the hidden state, the query can be formed from the embeddings of the decoder's current input token. This approach can sometimes provide a more direct alignment with the input features.
  2. Combination of Decoder Hidden States

    • A combination of the current and previous hidden states of the decoder can be used to form the query. This can provide richer context for computing attention weights.
    • For example, $q_t = \text{concat}(h_t, h_{t-1})$.
  3. Intermediate Layers of the Encoder

    • The query can be derived from an intermediate layer of the encoder instead of the final layer. This is particularly useful in deep encoder networks where intermediate representations might capture useful features.
  4. Learned Queries

    • In some models, queries are learned parameters. For instance, in Transformer models, queries are learned as part of the self-attention mechanism through parameter matrices applied to input embeddings.
  5. Self-Attention Mechanisms

    • In Transformer architectures, self-attention mechanisms use the same set of vectors (typically the input embeddings or their transformations) for queries, keys, and values. This means each element in the sequence attends to every other element, including itself.
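A minimal sketch of the first two options (hypothetical projection layers, toy dimensions similar to the earlier example):

import torch
import torch.nn as nn

embed_dim, hidden_dim, attn_dim = 3, 2, 2

# Option 1 (hypothetical): use the decoder's current input embedding as the query
input_emb = torch.randn(1, embed_dim)
W_emb_q = nn.Linear(embed_dim, attn_dim)       # project the embedding into the attention space
q_from_embedding = W_emb_q(input_emb)

# Option 2 (hypothetical): combine the current and previous decoder hidden states
h_t, h_prev = torch.randn(1, hidden_dim), torch.randn(1, hidden_dim)
W_cat_q = nn.Linear(hidden_dim * 2, attn_dim)  # project concat(h_t, h_{t-1}) into the attention space
q_from_states = W_cat_q(torch.cat([h_t, h_prev], dim=-1))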

How to Choose the Key and Value

In attention mechanisms, keys and values typically come from the encoder's hidden states, but there are various other options and extensions that can be used depending on the specific model architecture and task requirements. Here are some alternatives (a small sketch of options 1 and 6 follows the list):

  1. Intermediate Encoder Layers

    • Instead of using the final hidden states of the encoder, intermediate layers can be used as keys and values. This can capture different levels of abstraction in the input sequence.
  2. Previous Decoder States

    • In some models, previous decoder states or outputs are used as keys and values. This is common in self-attention mechanisms like in Transformers.
  3. Concatenation of Encoder States

    • Concatenating hidden states from different layers or time steps of the encoder to form richer keys and values.
  4. External Memory

    • External memory structures can be incorporated as keys and values, such as in memory-augmented neural networks.
  5. Learned Embeddings

    • Learned embeddings independent of the input sequence can be used. This approach can be seen in some variations of attention mechanisms in Transformers.
  6. Positional Encodings

    • Positional encodings can be added to the keys and values to incorporate positional information, especially in models like the Transformer.
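A minimal sketch of options 1 and 6 (hypothetical layers and toy shapes):

import torch
import torch.nn as nn

seq_len, embed_dim, hidden_dim = 3, 3, 2
x = torch.randn(1, seq_len, embed_dim)

# Option 1 (hypothetical): stack two single-layer RNNs so the intermediate layer's
# per-step outputs are accessible, and use them as keys/values instead of the top layer
rnn1 = nn.RNN(embed_dim, hidden_dim, batch_first=True)
rnn2 = nn.RNN(hidden_dim, hidden_dim, batch_first=True)
layer1_out, _ = rnn1(x)            # intermediate-layer states: (1, seq_len, hidden_dim)
layer2_out, _ = rnn2(layer1_out)   # top-layer states
keys = values = layer1_out         # intermediate layer as K/V

# Option 6 (hypothetical): add a learned positional embedding to the keys and values
pos_emb = nn.Embedding(seq_len, hidden_dim)
positions = pos_emb(torch.arange(seq_len)).unsqueeze(0)  # (1, seq_len, hidden_dim)
keys = keys + positions
values = values + positions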