Fine-Tuning BERT for Text Classification: A Simple Guide

What is BERT?

BERT (Bidirectional Encoder Representations from Transformers) is a language model created by Google that helps computers understand the meaning of words in sentences. It was trained on a large amount of text data, allowing it to solve many language-related problems without needing to be trained from scratch every time.

What is Fine-Tuning?

Fine-tuning is the process of taking a pre-trained model like BERT and training it for a specific task, such as classifying text into different categories (e.g., spam or not spam). Since BERT already knows a lot about language, fine-tuning requires less data and computation.

How Does Fine-Tuning Work?

  1. Pre-training: BERT is first trained on two tasks:

    • Masked Language Modeling (MLM): BERT learns to predict missing words in a sentence.

    • Next Sentence Prediction (NSP): BERT learns to determine if one sentence logically follows another.

These tasks help BERT understand language in a broad sense; a short masked-word sketch follows this list.

  2. Fine-Tuning: After pre-training, we fine-tune BERT for a specific task like text classification. This involves using labeled data—where each text is tagged with a category (e.g., "Safe" or "Not Safe")—to teach BERT how to classify new examples.
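
As a quick illustration of the masked language modeling objective from step 1, here is a minimal sketch that uses the Hugging Face fill-mask pipeline with the pre-trained bert-base-uncased model to guess a hidden word. The sentence is a made-up example and is not part of the fine-tuning workflow described below.

from transformers import pipeline

# Load a fill-mask pipeline backed by the pre-trained bert-base-uncased model
fill_mask = pipeline("fill-mask", model="bert-base-uncased")

# Ask BERT to predict the hidden word (the sentence is purely illustrative)
for prediction in fill_mask("The cat sat on the [MASK]."):
    print(prediction["token_str"], round(prediction["score"], 3))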

Example: Classifying Phishing URLs

Let’s say we want to use BERT to detect whether a website link is safe or phishing. We use a dataset of URLs labeled as either "bad" (phishing) or "good" (safe); the training data used in this article is the Phishing Site URLs dataset on Kaggle.

The dataset consists of thousands of URLs labeled as either "good" (safe) or "bad" (phishing), and is divided into training, validation, and testing sets. The training set is used to adjust the model's parameters, the validation set is used to monitor performance and guide training choices, and the testing set evaluates the final model on new, unseen data.
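
Before fine-tuning, it can help to peek at the data itself. The short sketch below loads the same Hugging Face dataset used later in this article (shawhin/phishing-site-classification) and prints its splits along with one training example; the exact split and column names it prints depend on how the dataset is stored.

from datasets import load_dataset

# Load the phishing URL dataset from the Hugging Face Hub
dataset = load_dataset("shawhin/phishing-site-classification")

# Show the available splits and their sizes, then inspect one labeled URL
print(dataset)
print(dataset["train"][0])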

To fine-tune BERT for classifying phishing URLs, we follow these steps:

  1. Load Pre-trained BERT Model: We use a version of BERT called "bert-base-uncased" that’s available on Hugging Face, which is a platform for accessing many pre-trained models.

  2. Tokenize the URLs: Tokenization is the process of breaking down the text (in this case, URLs) into smaller pieces, called tokens, that BERT can understand (a short tokenization sketch follows this list).

  3. Load the Model with a Binary Classification Head: This is a crucial step. Instead of using BERT directly, we add a custom classification layer to adapt BERT for our specific task—classifying URLs as "Safe" or "Not Safe." To do this, we use the AutoModelForSequenceClassification class, which adds a final layer with two possible outputs. We also use dictionaries (id2label and label2id) to map between numeric IDs and descriptive labels, so that the model knows how to represent its output in a meaningful way. This ensures that our model’s predictions are interpretable.

  4. Freeze Most Parameters: BERT has over 100 million parameters, which makes training from scratch difficult. Instead, we "freeze" most of these parameters, meaning we set them as non-trainable so that they remain unchanged during fine-tuning. Specifically, we do this by iterating through the model's parameters and setting requires_grad to False for all but the pooler layer. We unfreeze the pooler layer because it is responsible for transforming the output of the last hidden layer into a representation that is useful for classification tasks. By unfreezing the pooler, we allow the model to better adapt to our specific classification task, which can improve performance. This allows us to focus training on only the task-specific parts of the model, which saves time and resources.

  5. Train the Model: We train the model using our labeled dataset, telling BERT to adjust itself so it can predict whether a given URL is safe or not. During this step, we use a Data Collator to pad the token sequences in each batch to the same length, making the training process more efficient and preventing issues that can arise with sequences of different lengths.
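
To illustrate step 2, here is a minimal tokenization sketch. The URL is a made-up example, and the output shows the subword pieces that the bert-base-uncased tokenizer produces.

from transformers import AutoTokenizer

# Load the same tokenizer that will be used for fine-tuning
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize a single example URL (a made-up address used only for illustration)
encoding = tokenizer("http://example.com/login-update", truncation=True)

# Show the subword tokens BERT will actually see, including [CLS] and [SEP]
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))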

Code Example

Here’s a simplified Python code snippet for fine-tuning BERT using the Hugging Face Transformers and Datasets libraries:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding
from datasets import load_dataset

# Load BERT tokenizer and model
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define label mappings
id2label = {0: "Safe", 1: "Not Safe"}
label2id = {"Safe": 0, "Not Safe": 1}

# Load the model with a binary classification head
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, 
    num_labels=2, 
    id2label=id2label, 
    label2id=label2id
)

# Freeze all base model parameters
for param in model.base_model.parameters():
    param.requires_grad = False

# Unfreeze the pooler layer
for name, param in model.named_parameters():
    if "pooler" in name:
        param.requires_grad = True
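
# Optional sanity check: confirm how many parameters will actually be trained
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable:,} of {total:,}")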

# Load dataset
dataset = load_dataset("shawhin/phishing-site-classification")

# Map string labels to integer IDs to match the model's label mapping
# (only needed if the dataset stores labels as "good"/"bad" strings)
def map_labels(example):
    if example["label"] == "good":
        example["label"] = 0  # Corresponds to 'Safe'
    elif example["label"] == "bad":
        example["label"] = 1  # Corresponds to 'Not Safe'
    return example

dataset = dataset.map(map_labels)

# Tokenize the data
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True)

tokenized_data = dataset.map(tokenize_function, batched=True)

# Data Collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    data_collator=data_collator
)

# Train the model
trainer.train()
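
After training, a quick way to check the result is to ask the Trainer to evaluate the model on the eval dataset passed in above. This minimal sketch only reports the evaluation loss, since no compute_metrics function was supplied to the Trainer.

# Evaluate the fine-tuned model on the eval dataset passed to the Trainer
metrics = trainer.evaluate()
print(metrics)  # reports the evaluation loss (e.g., 'eval_loss') by default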

Save and Use the Fine-Tuned Model

After fine-tuning the model, we can save it for future use and perform inference on new examples.

To save the model:

# Save the fine-tuned model
tokenizer.save_pretrained("./saved_model")
model.save_pretrained("./saved_model")

To use the fine-tuned model for inference:

from transformers import pipeline

# Load the saved model
model_name = "./saved_model"
inference_pipeline = pipeline("text-classification", model=model_name, tokenizer=model_name)

# Perform inference on a new URL
url = "http://example.com/suspicious-link"
result = inference_pipeline(url)
print(result)

In the example above, we save both the tokenizer and the fine-tuned model to a folder named saved_model. We then use the pipeline from Hugging Face to easily load the saved model and perform inference on a new URL, classifying it as "Safe" or "Not Safe".

Conclusion

Fine-tuning BERT allows us to adapt a powerful pre-trained language model to a specific task without the huge cost of training from scratch. In our example, the model learns to classify URLs as safe or phishing, but the same approach applies to many other text classification tasks.

  • Saves Time and Resources: Fine-tuning requires only slight adjustments to an already well-trained model, saving significant time and computational power.

  • High Performance: BERT's existing language understanding capabilities make it highly effective for tasks like detecting phishing links or performing sentiment analysis.


Author Bio

Rafal Jackiewicz is the author of books about programming in C and Java. You can find more information about him and his work on Amazon.