Deep Learning: Training Our First Model

Discover how we fine‑tuned BERT on the AG News dataset with PyTorch’s efficient data pipelines, mixed‑precision training, and AdamW. See how the model reaches 94.6% accuracy, explore a confusion matrix of the class mix‑ups, and grab the ready‑to‑use model artefacts.

AI Modelling with BERT and the AG News dataset

In our previous lab, we explored classic machine learning by training a decision tree on the Iris dataset... learning how to clean data, select features, fit a model, and evaluate its accuracy. This time, we're advancing to transformer‑based NLP, fine‑tuning BERT on the AG News dataset.

First, I needed to make sure Python could access my GPU. Here’s a quick script:

import torch

print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))

Essentially, all this does is use the torch library to check whether CUDA is available and, if it is, print the name of the GPU it found.

If you are struggling to get the CUDA check to return "True" along with your GPU name, then I'd recommend double-checking that you've installed a CUDA-enabled build of PyTorch (rather than the CPU-only package) and that your NVIDIA drivers are up to date.

So, now that we've validated our environment can indeed access our GPU, it's time to start making it do something other than render game frames 🙈

📂 Code Repository: Explore the complete code and configurations for this article series on GitHub.

View Repository on GitHub

Training the Model

The plan is to build a news headline classification service from the BERT Base model (smaller and easier to work with than BERT Large) that, given a bunch of headlines, will categorise them into the right categories.

Okay, firstly, I wanted to map out the flow of what we needed to do... basically we'll load the AG News dataset, tokenise and batch it, load BERT onto the GPU, train for five epochs, then evaluate and save our model.

Loading the AG News Dataset

Let's get started by loading the dataset from Hugging Face... luckily for us there's a datasets package (Hugging Face's datasets library) that we can install with pip and leverage for just this purpose 🍀

from datasets import load_dataset

# load the dataset and extract the text and labels
dataset = load_dataset("ag_news")
train_texts = dataset['train']['text']
train_labels = dataset['train']['label']

So what we do here is pull in the ag_news dataset, which gives us two parallel sets of data... the headline text and the labels associated with each headline.
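
If you want to sanity-check what's been loaded, here's an optional peek at the first training example and the label names... this isn't part of the main pipeline, just a handy inspection step:

# optional sanity check: each example is a dict with a 'text' and a 'label' field
print(dataset['train'][0])

# the label feature maps 0-3 to the four AG News categories
print(dataset['train'].features['label'].names)  # ['World', 'Sports', 'Business', 'Sci/Tech']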

Tokenising our dataset

Next up we need to tokenise our text, which is done using the bert-base-uncased tokeniser... the one that matches the model we'll be fine-tuning.

from transformers import AutoTokenizer

# load the tokeniser and generate our tokens
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer(
  train_texts, 
  padding=True, 
  truncation=True, 
  return_tensors="pt"
)

Here we're asking BERT's tokeniser to pad and truncate each example so that every sequence ends up the same length, which is crucial for faster processing.

NB: a tokeniser is essentially a translation mechanism that converts words to numbers, the models understand numbers; see what-is-tokenisation for more information.
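
If you're curious whether the padding and truncation actually worked, an optional shape check on the returned tensors confirms every example shares the same length (the exact sequence length depends on the longest headline in the data):

# optional: every example should share the same padded length
print(tokens['input_ids'].shape)       # (num_examples, max_sequence_length)
print(tokens['attention_mask'].shape)  # same shape; 1s for real tokens, 0s for padding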

Next up is to convert our set of tokens into a TensorDataset (the bundle):

import torch
from torch.utils.data import TensorDataset, DataLoader

# convert our labels to a tensor and build the dataset bundle
labels = torch.tensor(train_labels)
dataset = TensorDataset(tokens['input_ids'], tokens['attention_mask'], labels)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

For this part, we're bundling the tokenised outputs and the label tensor into a TensorDataset and wrapping it in a DataLoader that serves shuffled batches of 32, with num_workers and pin_memory set to keep the GPU fed.

NB: A TensorDataset bundles the tensors from the tokenisation phase so that PyTorch can treat them as a single unit; see what-is-a-tensordataset for more information.
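
To double-check the loader is serving up what the training loop expects, here's an optional sketch that pulls a single batch from the train_loader defined above and prints its shapes:

# optional: grab one batch and confirm the shapes line up
batch_input_ids, batch_attention_mask, batch_labels = next(iter(train_loader))
print(batch_input_ids.shape)       # torch.Size([32, max_sequence_length])
print(batch_attention_mask.shape)  # same shape as the input IDs
print(batch_labels.shape)          # torch.Size([32])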

Load the optimiser and train the model

Sweet, now we're ready to create our model and kick off the training... here's one I made earlier 😆

from torch.amp import GradScaler, autocast
from torch.optim import AdamW
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification

# define our model as bert-base-uncased with a 4-class classification head
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4)

# train it with our dataset :)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)
scaler = GradScaler()

model.train()
for epoch in range(5):
  epoch_loss = 0
  print(f"\n🚀 Epoch {epoch+1}")
  for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
    input_ids, attention_mask, labels = [x.to(device) for x in batch]

    optimizer.zero_grad()

    with autocast(device_type=device.type):
      outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
      loss = outputs.loss

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    epoch_loss += loss.item()
  avg_loss = epoch_loss / len(train_loader)
  print(f"✅ Epoch {epoch+1} complete — Avg Loss: {avg_loss:.4f}")

Once we have everything set up, training proceeds in five full iterations (or epochs) over the AG News dataset. During each pass, the examples are grouped into small batches and sent to the GPU for speed. The model makes predictions on each batch and computes a loss value measuring its error. Using mixed‑precision arithmetic, that loss is back‑propagated through the network to calculate gradients, which the optimiser then uses to tweak the model's internal weights.

Meanwhile, the code keeps a running sum of all the loss values so that, once every batch in the epoch has been processed, it can divide by the number of batches and report an average loss... letting us see how training is improving (or not) from one epoch to the next.
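
With training finished, the last step from our plan is to save the fine-tuned weights (and the matching tokeniser) so they can be reloaded later without retraining. A minimal sketch... the output directory name here is just an example:

# persist the fine-tuned model and tokeniser to disk
output_dir = "bert-agnews-classifier"  # example path, pick whatever suits your project
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

Loading it back later is then just a matter of pointing AutoModelForSequenceClassification.from_pretrained and AutoTokenizer.from_pretrained at that directory.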

Evaluating the Model

Once training wrapped up, we ran our fine‑tuned BERT on the held‑out test set and achieved 94.6% accuracy. To see where the model still stumbled, we plotted the confusion matrix:

Confusion Matrix for the AG News dataset

From the matrix, it's clear that "Business" and "Sci/Tech" headlines occasionally get mixed up... around 110 "Business" examples were misclassified as "Sci/Tech", and about 126 "Sci/Tech" as "Business"... likely because tech articles often mention companies and markets.

This being said, "World" and "Sports" labels remain highly reliable, each with fewer than 60 total errors out of nearly 1,900 samples.
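
For reference, here's a hedged sketch of how an evaluation like this can be run: it reloads the test split (since we reused the dataset variable earlier for the TensorDataset), collects predictions batch by batch, and computes the accuracy and confusion matrix with scikit-learn. Exact numbers will vary slightly from run to run.

from datasets import load_dataset
from sklearn.metrics import accuracy_score, confusion_matrix

# reload the held-out test split and tokenise it the same way as the training data
test_split = load_dataset("ag_news", split="test")
test_tokens = tokenizer(test_split['text'], padding=True, truncation=True, return_tensors="pt")
test_dataset = TensorDataset(test_tokens['input_ids'], test_tokens['attention_mask'], torch.tensor(test_split['label']))
test_loader = DataLoader(test_dataset, batch_size=32)

# run the fine-tuned model over the test set without tracking gradients
model.eval()
predictions, actuals = [], []
with torch.no_grad():
  for batch in test_loader:
    input_ids, attention_mask, labels = [x.to(device) for x in batch]
    outputs = model(input_ids, attention_mask=attention_mask)
    predictions.extend(outputs.logits.argmax(dim=-1).cpu().tolist())
    actuals.extend(labels.cpu().tolist())

print(f"Accuracy: {accuracy_score(actuals, predictions):.4f}")
print(confusion_matrix(actuals, predictions))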

Next Time...

Next time, we'll look into supplementing our model with some new categories through LoRA training and we'll look at techniques for making our models faster and more optimised.

Concepts Deep Dive

If you're interested in understanding a bit more about tokenisation, what a tensor dataset is, or how an optimiser works, then check out the following sections... they're a bit more involved, but hopefully still consumable.

What is Tokenisation?

The tokeniser is like a translator... it takes human-readable text and turns it into something the model understands: numbers. Specifically, it:

  • Splits text into smaller known pieces called subwords using WordPiece tokenisation (e.g., "unbelievable" → "un", "##believ", "##able")
  • Maps each subword to a **token ID** from a pretrained vocabulary
  • Adds **special tokens** like `[CLS]` (start) and `[SEP]` (end)
  • Applies **padding** to make all sequences the same length
  • Outputs a tensor dictionary with:
    • `input_ids`: token IDs
    • `attention_mask`: 1's for real tokens, 0's for padding

This step ensures the text is aligned with the way the pretrained BERT model was trained... it's the bridge between language and math.

Example: Sentence "Playing video games is fun."

| Step | Result |
| --- | --- |
| Raw input | Playing Video Games is fun. |
| Lowercased | playing video games is fun. |
| WordPiece tokenisation | `[CLS]`, play, ##ing, video, games, is, fun, ., `[SEP]` |
| Token IDs | [101, 2378, 1475, 3040, 2267, 2003, 7023, 1012, 102] |
| Attention mask (1 = real token) | [1, 1, 1, 1, 1, 1, 1, 1, 1] |
| Token type IDs (all single‑sentence) | [0, 0, 0, 0, 0, 0, 0, 0, 0] |
  • [CLS] (101) and [SEP] (102) wrap the sequence for BERT classification
  • play → 2378; ##ing → 1475 splits the verb into subwords
  • video (3040) & games (2267) are in‑vocab tokens
  • Attention mask flags all tokens as real (no padding)¹
  • Token type IDs are all zeros since this is a single‑sentence input

¹ Padding is applied when we set a fixed length for tokens; it essentially fills out the rest of the mask with zeros, e.g. a fixed length of 10 for the sentence above would produce [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
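
If you want to reproduce a breakdown like the table above yourself, a small sketch with the same tokeniser makes it easy to inspect (the exact IDs come from the bert-base-uncased vocabulary, so trust the tokeniser's output over any hand-typed numbers):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("Playing video games is fun.")

# the human-readable tokens, including the [CLS] and [SEP] wrappers
print(tokenizer.convert_ids_to_tokens(encoded['input_ids']))

# the corresponding IDs, attention mask, and token type IDs
print(encoded['input_ids'])
print(encoded['attention_mask'])
print(encoded['token_type_ids'])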

What is a TensorDataset?

After tokenisation, we have multiple tensors: `input_ids`, `attention_mask`, and `labels`.

A TensorDataset simply bundles these together so PyTorch can treat them as a single unit, effectively like rows in a spreadsheet:

  • Each item = one row of input → (`input_ids`, `attention_mask`, `label`)
  • When paired with a `DataLoader`, it supports batching, shuffling, and efficient iteration

This format is essential for PyTorch to feed batches into the model during training.
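
As a tiny illustration of the "rows in a spreadsheet" idea, here's a minimal sketch using made-up tensors rather than the real AG News data:

import torch
from torch.utils.data import TensorDataset, DataLoader

# three fake examples: token IDs, an attention mask, and a label for each
input_ids = torch.tensor([[101, 7592, 2088, 102],
                          [101, 2204, 2851, 102],
                          [101, 3835, 4067, 102]])
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([0, 1, 3])

bundle = TensorDataset(input_ids, attention_mask, labels)
print(bundle[0])  # one "row": (input_ids, attention_mask, label) for the first example

# wrapping it in a DataLoader gives batching and shuffling for free
loader = DataLoader(bundle, batch_size=2, shuffle=True)
for batch_ids, batch_mask, batch_labels in loader:
  print(batch_ids.shape, batch_labels.shape)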

What is an Optimiser?

An optimiser in deep learning is the algorithm that adjusts your model’s parameters (its weights and biases) so that it learns to make better predictions. Think of training as a teacher giving the model "graded homework" (the loss), and the optimiser as the student's strategy for revising its answers to improve next time.

So why do we need one... well, in a nutshell:

  • Loss tells us how bad we did: After a forward pass, we compute a loss (e.g. cross‑entropy for classification)
  • Gradient tells us which way to move: Backpropagation computes the gradient of the loss with respect to each parameter, indicating how a small change in that parameter would affect the loss
  • Optimiser steps in: It uses those gradients to decide how much and in which direction to update each parameter

| Optimiser | Key Idea | Pros | Cons |
| --- | --- | --- | --- |
| SGD | θ ← θ − η · ∇θL (vanilla gradient descent) | Simple, easy to implement | Can be slow to converge; sensitive to learning rate |
| Momentum | Accumulates an exponentially decaying average of past gradients to "smooth" updates | Speeds up convergence in consistent directions | Adds an extra hyperparameter (momentum) |
| Adam | Combines momentum + adaptive learning rates per parameter (estimates first & second moments of gradients) | Fast convergence, works out of the box for many tasks | More memory; sometimes over‑fits if not regularised |
| AdamW | Variant of Adam that decouples weight decay (L2 regularisation) from the gradient update | Better generalisation; preferred for transformers like BERT | Slightly more complex hyperparameter tuning |
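
To make the first row of that table concrete, here's a small sketch showing the vanilla SGD update θ ← θ − η · ∇θL written by hand next to the equivalent torch.optim call (toy tensors, purely illustrative... in the real training loop above we use AdamW instead):

import torch

# a toy "model": a single parameter we want to nudge towards a target value
theta = torch.tensor([2.0], requires_grad=True)
target = torch.tensor([5.0])
lr = 0.1  # the learning rate (η)

# manual vanilla SGD: compute the loss, backpropagate, then step against the gradient
loss = ((theta - target) ** 2).mean()
loss.backward()
with torch.no_grad():
  theta -= lr * theta.grad  # θ ← θ − η · ∇θL
theta.grad.zero_()

# the same step via torch.optim (swap SGD for AdamW to get the behaviour used in this article)
optimizer = torch.optim.SGD([theta], lr=lr)
optimizer.zero_grad()
loss = ((theta - target) ** 2).mean()
loss.backward()
optimizer.step()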