BiRNN (Bidirectional RNN)

A Bidirectional RNN (BiRNN) is an extension of a standard RNN that processes the input sequence in both the forward and backward directions. At each time step $t$, two hidden states are computed and concatenated:

  • A forward hidden state $\overrightarrow{h_t}$, which depends on the past context $x_1, ..., x_t$
  • A backward hidden state $\overleftarrow{h_t}$, which depends on the future context $x_T, ..., x_t$

$$ h_t = \begin{bmatrix} \overrightarrow{h_t} \\ \overleftarrow{h_t} \end{bmatrix} \in \mathbb{R}^{2h} $$ which gives a richer representation that captures past and future information simultaneously.

There are three key questions about BiRNNs:

  1. How is a BiRNN different from a standard RNN?
  2. How do you implement a BiRNN layer in PyTorch?
  3. What are the two outputs of a BiRNN layer in PyTorch, and what are their shapes?

--- BiRNN Architecture¶

Each direction has its own parameters. The computations are as follows:

  • Forward RNN. At each time step $t = 1$ to $T$: $$ \overrightarrow{h_t} = \psi(W^{(f)}_{xh} x_t + W^{(f)}_{hh} \overrightarrow{h_{t-1}} + b^{(f)}_h) $$
  • Backward RNN. At each time step $t = T$ down to $1$: $$ \overleftarrow{h_t} = \psi(W^{(b)}_{xh} x_t + W^{(b)}_{hh} \overleftarrow{h_{t+1}} + b^{(b)}_h) $$

Where:

  • $x_t \in \mathbb{R}^d$ is the input at time $t$
  • $W^{(f)}_{xh}$, $W^{(f)}_{hh}$, $b^{(f)}_h$ and $W^{(b)}_{xh}$, $W^{(b)}_{hh}$, $b^{(b)}_h$ are the weight matrices and biases, with a separate set of parameters for each direction
  • $\psi$ is an activation function such as tanh or ReLU

The output $y_t$ can be computed from the concatenated hidden state $h_t = \begin{bmatrix} \overrightarrow{h_t} \\ \overleftarrow{h_t} \end{bmatrix}$: $$ y_t = \phi(W_{hy} h_t + b_y) $$

Where:

  • $W_{hy} \in \mathbb{R}^{o \times 2h}$ projects the concatenated hidden state to the output space
  • $b_y \in \mathbb{R}^o$ is the output bias
  • $\phi$ is typically an activation function such as softmax (for classification) or identity (for regression)
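
To make the two recurrences and the concatenation concrete, here is a minimal sketch of the equations above using plain tensor operations (not the nn.RNN layer used later). The sizes, the random weights, and the tanh/softmax choices are illustrative assumptions only.

# A minimal sketch of the BiRNN equations above, written with plain tensor ops.
# The sizes (d, h, o, T), the random weights, and the tanh/softmax choices are
# illustrative assumptions only.
import torch

torch.manual_seed(0)
d, h, o, T = 4, 3, 2, 5                       # input dim, hidden dim per direction, output dim, sequence length
x = torch.randn(T, d)                         # one input sequence x_1, ..., x_T

# separate parameters for the forward (f) and backward (b) directions
W_xh_f, W_hh_f, b_h_f = torch.randn(h, d), torch.randn(h, h), torch.zeros(h)
W_xh_b, W_hh_b, b_h_b = torch.randn(h, d), torch.randn(h, h), torch.zeros(h)
W_hy, b_y = torch.randn(o, 2 * h), torch.zeros(o)

# forward recurrence: t = 1 .. T
h_fwd, prev = [], torch.zeros(h)
for t in range(T):
    prev = torch.tanh(W_xh_f @ x[t] + W_hh_f @ prev + b_h_f)
    h_fwd.append(prev)

# backward recurrence: t = T .. 1
h_bwd, prev = [None] * T, torch.zeros(h)
for t in reversed(range(T)):
    prev = torch.tanh(W_xh_b @ x[t] + W_hh_b @ prev + b_h_b)
    h_bwd[t] = prev

# concatenate per time step, then project to the output space
h_cat = torch.stack([torch.cat([f, b]) for f, b in zip(h_fwd, h_bwd)])  # (T, 2h)
y = torch.softmax(h_cat @ W_hy.T + b_y, dim=-1)                         # (T, o)
print(h_cat.shape, y.shape)                                             # torch.Size([5, 6]) torch.Size([5, 2])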

--- BiRNN Implementation¶

This implementation builds a simple BiRNN-based model for Named Entity Recognition (NER), where each word in a sentence is classified into entity labels such as B-PER, B-LOC, or O. Given a toy vocabulary of size 6, the model uses an embedding dimension of 16, a hidden size of 32 per RNN direction (total hidden size = 64), and predicts over 3 possible tags. Each input sequence is of length 4, and training is performed using cross-entropy loss over word-level tag predictions.

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
In [3]:
# -------------------------------
# Model Definition
# -------------------------------

class BiRNN_NER(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, tagset_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(hidden_dim * 2, tagset_size)

    def forward(self, input_ids):
        embeds = self.embedding(input_ids)                 # (batch, seq_len, embed_dim)
        rnn_out, _ = self.rnn(embeds)                      # (batch, seq_len, 2*hidden)
        logits = self.classifier(rnn_out)                  # (batch, seq_len, tagset_size)
        return logits
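
As a quick check of question 3 above: calling a bidirectional nn.RNN returns two tensors, the per-step output and the final hidden state h_n (the value discarded as _ in forward above). A small sketch, assuming batch_first=True, a single layer, and the toy sizes used in this notebook:

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=16, hidden_size=32, bidirectional=True, batch_first=True)
dummy = torch.randn(2, 4, 16)        # (batch=2, seq_len=4, embed_dim=16)

output, h_n = rnn(dummy)
print(output.shape)   # torch.Size([2, 4, 64]): hidden state at every step, forward and backward concatenated (2*hidden)
print(h_n.shape)      # torch.Size([2, 2, 32]): final hidden state of each direction, shape (num_layers*2, batch, hidden)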
In [4]:
# -------------------------------
# Toy Dataset
# -------------------------------

vocab = {'John': 0, 'lives': 1, 'in': 2, 'Paris': 3, 'Mary': 4, 'London': 5}
# B-PER: Beginning of Person entity (e.g. John, Mary)
# O: Outside of any entity (e.g. lives, in) 
# B-LOC: Beginning of Location entity (e.g. Paris, London)
tag2idx = {'B-PER': 0, 'O': 1, 'B-LOC': 2}
idx2tag = {v: k for k, v in tag2idx.items()}

sentences = [
    ['John', 'lives', 'in', 'Paris'],
    ['Mary', 'lives', 'in', 'London']
]
labels = [
    ['B-PER', 'O', 'O', 'B-LOC'],
    ['B-PER', 'O', 'O', 'B-LOC']
]

# encode() converts words and labels to tensor of indices
# Example:
# sentence = ['John', 'lives', 'in', 'Paris'] -> input_ids = tensor([0, 1, 2, 3])
# label = ['B-PER', 'O', 'O', 'B-LOC'] -> label_ids = tensor([0, 1, 1, 2])
def encode(sentence, label):
    input_ids = torch.tensor([vocab[w] for w in sentence], dtype=torch.long)
    label_ids = torch.tensor([tag2idx[l] for l in label], dtype=torch.long)
    return input_ids, label_ids

# dataset is a list of (input_ids, label_ids) tuples
# Example: [
#   (tensor([0, 1, 2, 3]), tensor([0, 1, 1, 2])),  # John lives in Paris
#   (tensor([4, 1, 2, 5]), tensor([0, 1, 1, 2]))   # Mary lives in London
# ]
dataset = [encode(s, l) for s, l in zip(sentences, labels)]


# -------------------------------
# Training Configuration
# -------------------------------

embedding_dim = 16
hidden_dim = 32
vocab_size = len(vocab)
tagset_size = len(tag2idx)

model = BiRNN_NER(vocab_size, embedding_dim, hidden_dim, tagset_size)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
In [5]:
# -------------------------------
# Training Process
# -------------------------------

print("🔧 Training...")
for epoch in range(20):
    total_loss = 0
    model.train()
    for input_ids, label_ids in dataset:
        input_ids = input_ids.unsqueeze(0)         # (1, seq_len)
        label_ids = label_ids.unsqueeze(0)         # (1, seq_len)

        optimizer.zero_grad()
        logits = model(input_ids)                  # (1, seq_len, tagset_size)

        # reshape for loss: (batch*seq_len, tagset_size) vs (batch*seq_len)
        loss = criterion(logits.view(-1, tagset_size), label_ids.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: loss = {total_loss:.4f}")
🔧 Training...
Epoch 1: loss = 2.0754
Epoch 2: loss = 0.8039
Epoch 3: loss = 0.3287
Epoch 4: loss = 0.1245
Epoch 5: loss = 0.0506
Epoch 6: loss = 0.0239
Epoch 7: loss = 0.0130
Epoch 8: loss = 0.0078
Epoch 9: loss = 0.0051
Epoch 10: loss = 0.0035
Epoch 11: loss = 0.0026
Epoch 12: loss = 0.0020
Epoch 13: loss = 0.0016
Epoch 14: loss = 0.0013
Epoch 15: loss = 0.0011
Epoch 16: loss = 0.0010
Epoch 17: loss = 0.0009
Epoch 18: loss = 0.0008
Epoch 19: loss = 0.0007
Epoch 20: loss = 0.0007
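
After training, the model can be used for tagging. The cell below is an illustrative inference sketch (not part of the original training code): it runs the model in eval mode on the two training sentences, takes the argmax tag per word, and maps indices back to labels with idx2tag.

# -------------------------------
# Inference (illustrative sketch)
# -------------------------------

model.eval()
with torch.no_grad():
    for sentence in sentences:
        input_ids = torch.tensor([vocab[w] for w in sentence]).unsqueeze(0)  # (1, seq_len)
        logits = model(input_ids)                                            # (1, seq_len, tagset_size)
        pred_ids = logits.argmax(dim=-1).squeeze(0)                          # (seq_len,)
        tags = [idx2tag[i.item()] for i in pred_ids]
        # with the training loss near zero, this should recover the gold tags,
        # e.g. ('John', 'B-PER'), ('lives', 'O'), ('in', 'O'), ('Paris', 'B-LOC')
        print(list(zip(sentence, tags)))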

--- BiRNN History And Use Cases¶

BiRNNs were among the first highly successful models to combine deep learning with bidirectional contextual encoding for sequence labeling tasks. Around 2015, the introduction of BiRNN + CRF models, especially those based on BiLSTMs, marked the first time deep learning methods significantly outperformed traditional statistical approaches such as HMMs and CRFs (Huang et al. (2015), Bidirectional LSTM-CRF Models for Sequence Tagging). As a result, the BiRNN (particularly the BiLSTM) became one of the first high-performance, end-to-end deep learning solutions for NLP sequence labeling tasks, and remained a standard architecture in the field until the rise of Transformers.

Bidirectional structures had already been explored in earlier speech recognition research. For example, the first systematic application of a BiRNN (specifically a BiLSTM) to sequence modeling can be found in the seminal work of Graves & Schmidhuber (2005), Framewise phoneme classification with bidirectional LSTM, applied to speech data.

Therefore, the BiRNN was among the earliest models to successfully introduce bidirectional structures at scale in NLP, achieving strong results in tasks like NER and POS tagging, and it has since been widely used across many types of sequence labeling tasks.

  • Named Entity Recognition (NER) Example: use BiRNNs to classify each word in a sentence into predefined entity types such as person (B-PER), location (B-LOC), or organization (B-ORG). Input is a sequence of word embeddings (or one-hot vectors), and at each time step the output is a label indicating the entity type. The forward RNN captures context from left to right, while the backward RNN captures context from right to left, allowing the model to make tagging decisions using full sentence information. For example, in “John lives in Paris”, a BiRNN can tag “John” as a person by also seeing the future context “lives in Paris”, and can tag “Paris” as a location from the preceding “lives in”.

  • Part-of-Speech (POS) Tagging Example: assign grammatical labels (e.g., noun, verb, adjective) to each word in a sentence. Input is a word sequence, and the output is a sequence of POS tags. BiRNNs are effective because the correct POS of a word often depends on both its previous and next words. For instance, the word “book” could be a noun or a verb, and BiRNN can use both directions of context to disambiguate.

  • Chunking / Shallow Parsing Example: identify phrase-level structures such as noun phrases (NP) or verb phrases (VP) in a sentence. Input is a tokenized sentence; the output is a label sequence marking chunk boundaries. BiRNNs help by providing global sentence context, enabling more accurate chunk boundary detection (e.g., distinguishing between “New York” as a noun phrase vs. separate tokens).

  • Speech and Phoneme Recognition Example: map sequences of audio frames to sequences of phoneme labels. Input is a sequence of acoustic features (e.g., MFCCs); output is a sequence of phonemes or characters. Since speech is temporally continuous, BiRNNs help improve recognition accuracy by leveraging both past and future acoustic cues.

  • Textual Emotion or Sentiment Tagging Example: assign emotion or sentiment labels to each word or subphrase within a sentence, useful in fine-grained opinion mining. Input is a sequence of word embeddings; output is a label sequence indicating emotional state (e.g., positive, negative, neutral). BiRNNs capture surrounding emotional context, allowing the model to distinguish phrases like “not good” from “good” by considering words before and after.