BERT

BERT (Bidirectional Encoder Representations from Transformers) is a Transformer-based model that uses only the encoder part of the Transformer architecture. Because it has no decoder, BERT by itself cannot perform generative tasks such as text generation; those tasks require additional task-specific heads, while BERT itself is mainly used for understanding and representing the text.

BERT is pre-trained on large text corpora and then fine-tuned for specific applications. Pre-training uses two objectives:

  • Masked Language Modeling (MLM), where some words are masked and the model learns to predict them.
  • Next Sentence Prediction (NSP), where the model sees a pair of sentences and predicts whether sentence B actually follows sentence A in the original text.

ELMo (Embeddings from Language Models; Peters et al., 2018), which is based on a BiLSTM, initiated the era of contextualized embeddings. BERT brought the idea to maturity, becoming the standard for contextual semantic representation and having a revolutionary impact on the entire NLP field.

Before contextualized embeddings, the commonly used word vectors (Word2Vec, GloVe, etc.) were static: "bank" gets the same vector whether it means a river bank or a financial bank. Contextualized embeddings are dynamic, so the same word gets a different embedding in different contexts.
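
As a quick illustration (a minimal sketch, assuming the bert-base-uncased checkpoint from the Hugging Face transformers library), the cell below extracts the contextual vector of "bank" from two different sentences. A static embedding would give identical vectors; BERT's contextual embeddings differ:

In [ ]:
import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

def bank_vector(sentence):
    # Return the contextual embedding of the token "bank" in the given sentence
    inputs = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        hidden = bert(**inputs).last_hidden_state[0]  # (seq_len, hidden_size)
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    return hidden[tokens.index('bank')]

v_river = bank_vector("He sat on the bank of the river.")
v_money = bank_vector("She deposited cash at the bank.")

# Static vectors would give cosine similarity 1.0; contextual vectors do not
similarity = torch.nn.functional.cosine_similarity(v_river, v_money, dim=0)
print("cosine similarity between the two 'bank' vectors:", similarity.item())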

Takeaways:

  1. What are contextualized embeddings?
  2. What are the [CLS] and [SEP] tokens?
  3. What are the input and output of MLM, NSP specifically?
  4. Is BERT typically fine-tuned using full-parameter fine-tuning or partial-parameter fine-tuning?
  5. Have the parameters been modified a lot during fine tuning?

--- CLS and SEP¶

For NSP in BERT, the role of [CLS] and [SEP] is straightforward. For Masked Language Modeling (MLM), the [CLS] and [SEP] tokens are still included, even though they aren't directly related to predicting the masked tokens. Here's why they are included:

  1. [CLS] Token: The final hidden state of [CLS] represents the entire sentence, which can be useful in downstream tasks such as text classification. Including [CLS] during MLM pre-training ensures that the model's pre-training process aligns with fine-tuning tasks where [CLS] is necessary (e.g., sentence classification).

  2. [SEP] Token: In tasks where multiple sentences are involved, the [SEP] token helps differentiate them. Even in MLM, including [SEP] ensures that the model is trained to recognize sentence boundaries, which is crucial for some downstream tasks.
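
To make this concrete, here is a minimal sketch (assuming the bert-base-uncased tokenizer) of how a sentence pair is wrapped with [CLS] and [SEP], and how token_type_ids mark which segment each token belongs to:

In [ ]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# The tokenizer inserts [CLS] and [SEP] automatically for a sentence pair
encoded = tokenizer("The fox jumps.", "It runs away.")

print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))
print(encoded['token_type_ids'])  # 0 for sentence A (and [CLS]/first [SEP]), 1 for sentence B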

--- Masked Language Modeling¶

Step 1: Original Sentence¶

Consider the sentence:

  • "The quick brown fox jumps over the lazy dog."

Step 2: Masking Tokens¶

In the MLM task, some of the tokens in the sentence are randomly selected and replaced with a [MASK] token. Let's say we randomly choose the words "quick" and "lazy" to mask:

  • "[CLS] The [MASK] brown fox jumps over the [MASK] dog. [SEP]"

Step 3: Input to BERT¶

This masked sentence is then fed into BERT. BERT's job during pre-training is to predict the original words that were masked out. So the input to BERT would look like:

  • Tokens: ["[CLS]", "The", "[MASK]", "brown", "fox", "jumps", "over", "the", "[MASK]", "dog", ".", "[SEP]"]
  • Target Predictions: ["quick", "lazy"]

Step 4: Prediction¶

After processing the sentence, BERT generates a probability distribution over its vocabulary for each masked position. It then predicts the most likely word for each [MASK]. In our case:

  • Predicted word for the first [MASK]: "quick"
  • Predicted word for the second [MASK]: "lazy"

Step 5: Loss Calculation¶

The predicted distributions at the masked positions are compared to the actual masked words ("quick" and "lazy"). The cross-entropy loss between them is then backpropagated to update BERT’s parameters.
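
In the Hugging Face implementation, BertForMaskedLM computes this cross-entropy loss for you when you pass labels; positions that should not contribute to the loss are set to -100. A minimal sketch of the idea (assuming the bert-base-uncased checkpoint):

In [ ]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

text = "The quick brown fox jumps over the lazy dog."
inputs = tokenizer(text, return_tensors='pt')

# Keep the original ids as labels, then mask positions 2 ("quick") and 8 ("lazy")
labels = inputs['input_ids'].clone()
masked_positions = [2, 8]
inputs['input_ids'][0, masked_positions] = tokenizer.mask_token_id

# Only the masked positions should contribute; -100 is ignored by the loss
ignore = torch.ones_like(labels, dtype=torch.bool)
ignore[0, masked_positions] = False
labels[ignore] = -100

with torch.no_grad():
    outputs = model(**inputs, labels=labels)
print("MLM cross-entropy loss over the masked positions:", outputs.loss.item())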

Masked Language Modeling Example Code¶

In [ ]:
import torch
from transformers import BertTokenizer, BertForMaskedLM

# Load pre-trained BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
In [5]:
# Step 1: Original Sentence
sentence = "The quick brown fox jumps over the lazy dog."

# Step 2: Masking Tokens
# Tokenize the input sentence
input_ids = tokenizer.encode(sentence, return_tensors='pt')
print("input ids: ", input_ids)

# Mask the tokens "quick" and "lazy"
masked_indices = [2, 8]  # Indices of "quick" and "lazy" in tokenized sentence
input_ids[0, masked_indices] = tokenizer.mask_token_id
print("input ids after mask: ", input_ids) # the id of [MASK] is 103

# Convert input IDs back to tokens to show the masking
masked_sentence = tokenizer.convert_ids_to_tokens(input_ids[0])
print("Masked Sentence:", ' '.join(masked_sentence))
input ids:  tensor([[  101,  1996,  4248,  2829,  4419, 14523,  2058,  1996, 13971,  3899,
          1012,   102]])
input ids after mask:  tensor([[  101,  1996,   103,  2829,  4419, 14523,  2058,  1996,   103,  3899,
          1012,   102]])
Masked Sentence: [CLS] the [MASK] brown fox jumps over the [MASK] dog . [SEP]
In [7]:
# Step 3: Feed Masked Sentence into BERT
with torch.no_grad():
    outputs = model(input_ids)
    predictions = outputs.logits

# Step 4: Prediction
# Get the predicted token ids for the masked positions
predicted_ids = torch.argmax(predictions, dim=-1)
print("predicted_ids: ", predicted_ids) # we can see the prediction for first mask is 2210; for the second mask is also 2210
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids[0, masked_indices])

print("Predicted Tokens:", predicted_tokens)

# Step 5: Compare Predictions with Actual Tokens
actual_tokens = ["quick", "lazy"]
print("Actual Tokens:   ", actual_tokens)
predicted_ids:  tensor([[ 1012,  1996,  2210,  2829,  4419, 14523,  2058,  1996,  2210,  3899,
          1012,  1012]])
Predicted Tokens: ['little', 'little']
Actual Tokens:    ['quick', 'lazy']
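
As a small follow-up (this assumes the predictions, masked_indices, and tokenizer variables from the cells above are still in scope), we can inspect the top candidate tokens for each masked position rather than only the argmax:

In [ ]:
# Top-5 candidate tokens for each masked position
probs = torch.softmax(predictions[0, masked_indices], dim=-1)  # shape: (2, vocab_size)
top_probs, top_ids = torch.topk(probs, k=5, dim=-1)
for pos, ids, ps in zip(masked_indices, top_ids, top_probs):
    tokens = tokenizer.convert_ids_to_tokens(ids.tolist())
    print(f"position {pos}:", list(zip(tokens, [round(p.item(), 3) for p in ps])))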

--- Next Sentence Prediction¶

Step 1: Example Sentences¶

Consider the following two sentences:

  • Sentence A: "The quick brown fox jumps over the lazy dog."
  • Sentence B: "It then runs into the forest."

Step 2: Creating a Sentence Pair¶

In the NSP task, BERT is given two sentences and is trained to predict whether the second sentence (Sentence B) is the actual next sentence that follows the first one (Sentence A) in the original text.

  • Positive Example (IsNext): Sentence A is followed by Sentence B.

    • Sentence Pair: ("The quick brown fox jumps over the lazy dog.", "It then runs into the forest.")
  • Negative Example (NotNext): Sentence B is randomly selected and does not follow Sentence A in the original text.

    • Sentence Pair: ("The quick brown fox jumps over the lazy dog.", "The sun sets in the west.")

Step 3: Input to BERT¶

For the NSP task, both sentences are combined into a single input with special tokens:

  • [CLS] at the beginning of the sequence (used for classification tasks).
  • [SEP] to separate Sentence A and Sentence B.

For example, for the positive example:

  • Input: [CLS] The quick brown fox jumps over the lazy dog. [SEP] It then runs into the forest. [SEP]

Step 4: BERT's Processing¶

BERT processes the entire input sequence and outputs a representation for each token. The representation of the [CLS] token is particularly important, as it is used for the NSP classification.

Step 5: Prediction¶

The [CLS] token's representation is passed through a simple classification layer that outputs probabilities for two classes:

  • IsNext: The second sentence follows the first (index 0 in the Hugging Face implementation).
  • NotNext: The second sentence does not follow the first, i.e., it is a random sentence (index 1).

Next Sentence Prediction Example Code¶

In [9]:
from transformers import BertForNextSentencePrediction

# Load pre-trained BERT tokenizer and model for NSP
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
In [12]:
# Step 1: Example Sentences
sentence_a = "The quick brown fox jumps over the lazy dog."
sentence_b_isnext = "It then runs into the forest."  # Positive example
sentence_b_notnext = "The sun sets in the west."  # Negative example

# Step 2: Tokenization and Creating Input for BERT
positive_example = tokenizer.encode_plus(sentence_a, sentence_b_isnext, return_tensors='pt')
negative_example = tokenizer.encode_plus(sentence_a, sentence_b_notnext, return_tensors='pt')
print("positive_example: ", positive_example) # 101: CLS; 102: SEP
print("negative_example: ", negative_example) # 101: CLS; 102: SEP

# Step 3: Predict Next Sentence for Positive Example
with torch.no_grad():
    positive_outputs = model(**positive_example)
    negative_outputs = model(**negative_example)
    print("positive_outputs: ", positive_outputs)
    print("negative_outputs: ", negative_outputs)

# Step 4: Extract Predictions
positive_logits = positive_outputs.logits
negative_logits = negative_outputs.logits
print("positive_logits: ", positive_logits)
print("negative_logits: ", negative_logits)

# Step 5: Display Results
# In Hugging Face's BertForNextSentencePrediction, index 0 = IsNext (sentence B follows
# sentence A) and index 1 = NotNext (sentence B is a random sentence).
positive_prediction = torch.argmax(positive_logits, dim=-1).item()
negative_prediction = torch.argmax(negative_logits, dim=-1).item()

print(f"Positive Example Prediction (0 = IsNext, 1 = NotNext): {positive_prediction}")
print(f"Negative Example Prediction (0 = IsNext, 1 = NotNext): {negative_prediction}")
positive_example:  {'input_ids': tensor([[  101,  1996,  4248,  2829,  4419, 14523,  2058,  1996, 13971,  3899,
          1012,   102,  2009,  2059,  3216,  2046,  1996,  3224,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
negative_example:  {'input_ids': tensor([[  101,  1996,  4248,  2829,  4419, 14523,  2058,  1996, 13971,  3899,
          1012,   102,  1996,  3103,  4520,  1999,  1996,  2225,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
positive_outputs:  NextSentencePredictorOutput(loss=None, logits=tensor([[ 6.0892, -5.8013]]), hidden_states=None, attentions=None)
negative_outputs:  NextSentencePredictorOutput(loss=None, logits=tensor([[ 5.4474, -5.0226]]), hidden_states=None, attentions=None)
positive_logits:  tensor([[ 6.0892, -5.8013]])
negative_logits:  tensor([[ 5.4474, -5.0226]])
Positive Example Prediction (0 = IsNext, 1 = NotNext): 0
Negative Example Prediction (0 = IsNext, 1 = NotNext): 0
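
As a follow-up (reusing positive_logits and negative_logits from above), we can turn the [CLS] logits into probabilities. Remember that index 0 is the IsNext class in the Hugging Face convention; note that this checkpoint assigns a high IsNext probability to both pairs, which suggests NSP can be unreliable on such short, generic sentences:

In [ ]:
# Convert the NSP logits into probabilities; index 0 = IsNext, index 1 = NotNext
positive_probs = torch.softmax(positive_logits, dim=-1)
negative_probs = torch.softmax(negative_logits, dim=-1)
print("P(IsNext) for the positive pair:", positive_probs[0, 0].item())
print("P(IsNext) for the negative pair:", negative_probs[0, 0].item())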

--- A Toy BERT Implementation¶

We define a simple BERT model with an embedding layer, multiple Transformer encoder layers, and two linear heads for Masked Language Modeling (MLM) and Next Sentence Prediction (NSP). For simplicity, this toy model omits the positional and segment embeddings that real BERT adds to the token embeddings.

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
In [ ]:
# Define the BERT Model
class BERTModel(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_heads):
        super(BERTModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.transformer_blocks = nn.ModuleList([
            # batch_first=True so inputs are interpreted as (batch, seq_len, hidden_size)
            nn.TransformerEncoderLayer(hidden_size, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.linear_mlm = nn.Linear(hidden_size, vocab_size)
        self.linear_nsp = nn.Linear(hidden_size, 2)

    def forward(self, input_ids, attention_mask=None):
        x = self.embedding(input_ids)
        for transformer in self.transformer_blocks:
            x = transformer(x, src_key_padding_mask=attention_mask) # transformer encoder performs self-attention on src (Q,K,V all src)
        mlm_logits = self.linear_mlm(x)
        nsp_logits = self.linear_nsp(x[:, 0, :])  # Use [CLS] token representation for NSP
        return mlm_logits, nsp_logits
In [16]:
# Define a toy dataset
class ToyDataset(Dataset):
    def __init__(self, vocab_size, num_samples, max_len):
        self.vocab_size = vocab_size
        self.num_samples = num_samples
        self.max_len = max_len
        self.data = self._generate_data()

    def _generate_data(self):
        data = []
        for _ in range(self.num_samples):
            tokens = np.random.randint(0, self.vocab_size, self.max_len)
            label = np.random.randint(0, 2)
            data.append((tokens, label))
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, label = self.data[idx]
        return torch.tensor(tokens), torch.tensor(label)
In [31]:
# Hyperparameters
vocab_size = 30522  # Size of the vocabulary
hidden_size = 128  # Hidden size of the transformer
num_layers = 2  # Number of transformer layers
num_heads = 4  # Number of attention heads
num_samples = 1000  # Number of samples in the dataset
max_len = 40  # Maximum length of a sequence
batch_size = 40  # Batch size
num_epochs = 5  # Number of epochs (matches the training output below)

# Prepare dataset and dataloader
dataset = ToyDataset(vocab_size, num_samples, max_len)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize the model, loss function and optimizer
model = BERTModel(vocab_size, hidden_size, num_layers, num_heads)
criterion_mlm = nn.CrossEntropyLoss()
criterion_nsp = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
In [39]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    i = 0
    for batch in dataloader:
        i += 1
        input_ids, labels = batch
        if epoch == 1 and i == 1:
            print("input_ids shape: ", input_ids.shape)
            print("input_ids: ", input_ids)
            print("labels shape: ", labels.shape)
            print("input_labels: ", labels)
        attention_mask = input_ids == 0  # Treat token id 0 as padding; True marks positions ignored by attention
        if epoch == 1 and i == 1:
            print("attention mask_shape: ", attention_mask.shape)
            # print("attention_mask: ", attention_mask)
        mlm_logits, nsp_logits = model(input_ids, attention_mask)
        if epoch == 1 and i == 1:
            print("mlm_logits shape: ", mlm_logits.shape) # torch.Size([40, 40, 30522]). 30522 is the vocab size. Every token will get a dim of 30522 prob vector, where the max corresponds to the prediction.
            # print("mlm_logits: ", mlm_logits)
            print("nsp_logits shape: ", nsp_logits.shape)
            # print("nsp_logits: ", nsp_logits)


        # Create fake masked tokens for MLM task
        masked_indices = torch.rand(input_ids.shape).round().bool()
        mlm_labels = input_ids.clone()
        mlm_labels[~masked_indices] = -100  # We only compute loss on masked tokens. assign -100 to non-masked indices
        if epoch == 1 and i == 1:
            print("shape of mlm_labels: ", mlm_labels.shape)
            print("mlm_labels: ", mlm_labels)

        # Compute the losses
        loss_mlm = criterion_mlm(mlm_logits.view(-1, vocab_size), mlm_labels.view(-1))
        loss_nsp = criterion_nsp(nsp_logits, labels)
        loss = loss_mlm + loss_nsp

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print("total number of batch: ", i)

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")
    print("\n")

print("Training complete!")
total number of batch:  25
Epoch 1/5, Loss: 7.8066


input_ids shape:  torch.Size([40, 40])
input_ids:  tensor([[15075,  9582, 13458,  ..., 28467, 15463, 25064],
        [  154, 27279,  3701,  ...,  6136, 26707,  7985],
        [ 7060, 16128, 24674,  ...,  9765,  2063, 15723],
        ...,
        [17502, 15297,  7671,  ..., 27725, 18108,  6494],
        [ 9838, 17169,  3905,  ..., 30275, 14923, 18063],
        [24345, 15202, 14857,  ..., 11169, 22190, 30484]])
labels shape:  torch.Size([40])
input_labels:  tensor([1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0,
        1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0])
attention mask_shape:  torch.Size([40, 40])
mlm_logits shape:  torch.Size([40, 40, 30522])
nsp_logits shape:  torch.Size([40, 2])
shape of mlm_labels:  torch.Size([40, 40])
mlm_labels:  tensor([[ -100,  -100, 13458,  ..., 28467,  -100, 25064],
        [  154, 27279,  -100,  ...,  6136, 26707,  -100],
        [ -100,  -100,  -100,  ...,  -100,  2063,  -100],
        ...,
        [17502, 15297,  -100,  ...,  -100, 18108,  -100],
        [ 9838,  -100,  3905,  ..., 30275,  -100, 18063],
        [24345, 15202,  -100,  ..., 11169, 22190, 30484]])
total number of batch:  25
Epoch 2/5, Loss: 7.7383


total number of batch:  25
Epoch 3/5, Loss: 7.6447


total number of batch:  25
Epoch 4/5, Loss: 7.5684


total number of batch:  25
Epoch 5/5, Loss: 7.4909


Training complete!

--- Fine Tune BERT¶

In most BERT fine-tuning setups, all of the model's weights are updated, specifically:

  • All parameters of BERT are fine-tuned, including the transformer layers and embedding layers, to adapt the model to the specific task.
  • The classification layer (or any task-specific layers added during fine-tuning) is trained from scratch during the fine-tuning process.

Fine-tuning BERT is relatively fast (< 5 epochs) compared to training a model from scratch for several reasons:

  • Well-Calibrated Initialization: The pre-trained weights are already well-calibrated, meaning they are close to a good solution for many tasks. Fine-tuning doesn't have to search as extensively through the parameter space, which speeds up the optimization process.

  • Less Data Required: Fine-tuning typically involves much smaller datasets compared to the enormous datasets used during pre-training. The model doesn't need to learn language patterns from scratch; it only needs to adjust to task-specific nuances.

Notice that when we fine-tune BERT, we use a very small learning rate (e.g., 1e-5 to 5e-5). The reason is not that the loss landscape is especially steep and we need to control the step size; rather, BERT's pre-trained parameters already sit in a complex but very narrow basin. If the weights change too much, the pre-trained language structure is destroyed and the loss quickly increases.
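
Full-parameter fine-tuning is the common default, but partial fine-tuning is also possible: freeze the pre-trained encoder and train only the task head. A minimal sketch (assuming BertForSequenceClassification, which exposes its encoder as model.bert):

In [ ]:
from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Freeze every parameter of the pre-trained encoder and embeddings...
for param in model.bert.parameters():
    param.requires_grad = False

# ...so only the newly initialized classification head is updated during training
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print("Trainable parameters:", trainable)  # typically just classifier.weight and classifier.bias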

Toy Fine Tuning Implementation¶

In the following example, we fine-tune the BertForSequenceClassification model. This class is provided by the transformers library; it extends BertPreTrainedModel and wraps a BertModel together with a classification head. A glimpse of its architecture:

from transformers import BertPreTrainedModel, BertModel

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()
In [ ]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import numpy as np

# Load pre-trained BERT model and tokenizer
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Sample dataset (replace with your own dataset)
reviews = [
    "This movie was fantastic! I really enjoyed it.",
    "The acting was terrible and the plot was boring.",
    "I loved the characters, but the pacing felt off.",
    "Overall, it was a decent film with some flaws."
]
labels = [1, 0, 1, 0]  # 1 for positive sentiment, 0 for negative sentiment
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
In [47]:
# Tokenize input texts
tokenized_texts = [tokenizer.encode(review, max_length=128, padding='max_length', truncation=True) for review in reviews]
print("tokenized text len: ", len(tokenized_texts))
print("tokenized_texts: ", tokenized_texts)

# Convert tokenized texts to tensors
input_ids = torch.tensor(tokenized_texts)
labels = torch.tensor(labels)
print("input_ids shape: ", input_ids.shape)
print("labels shape: ", labels.shape)

# Create TensorDataset
dataset = TensorDataset(input_ids, labels)

# Split dataset into train and test sets
train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)

# DataLoader for training and testing
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
tokenized text len:  4
tokenized_texts:  [[101, 2023, 3185, 2001, 10392, 999, 1045, 2428, 5632, 2009, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1996, 3772, 2001, 6659, 1998, 1996, 5436, 2001, 11771, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1045, 3866, 1996, 3494, 1010, 2021, 1996, 15732, 2371, 2125, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 3452, 1010, 2009, 2001, 1037, 11519, 2143, 2007, 2070, 21407, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
input_ids shape:  torch.Size([4, 128])
labels shape:  torch.Size([4])
In [56]:
# Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # move the model to the same device as the batches
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
model.train()
for epoch in range(1):
    i = 0
    for batch in train_loader:
        i += 1
        batch = tuple(t.to(device) for t in batch)
        input_ids, labels = batch
        if i == 1:
            print("input_ids shape: ", input_ids.shape)
            # print("input_ids: ", input_ids)
            print("labels shape: ", labels.shape)
            # print("labels: ", labels)
        outputs = model(input_ids, labels=labels)
        # print("outputs shape ", outputs.shape)
        print("outputs: ", outputs)
        print("\n")
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluation
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        batch = tuple(t.to(device) for t in batch)
        input_ids, labels = batch
        outputs = model(input_ids) # for eval, we don't input labels here
        print("eval outputs: ", outputs)
        _, predicted = torch.max(outputs.logits, 1)
        print("predicted: ", predicted)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Accuracy: {accuracy}')
input_ids shape:  torch.Size([2, 128])
labels shape:  torch.Size([2])
outputs:  SequenceClassifierOutput(loss=tensor(0.3951, grad_fn=<NllLossBackward0>), logits=tensor([[ 0.1361, -0.2867],
        [-0.4043,  0.7001]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


outputs:  SequenceClassifierOutput(loss=tensor(0.2212, grad_fn=<NllLossBackward0>), logits=tensor([[-0.6949,  0.7010]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)


eval outputs:  SequenceClassifierOutput(loss=None, logits=tensor([[ 0.2154, -0.1707]]), hidden_states=None, attentions=None)
predicted:  tensor([0])
Accuracy: 1.0