Learning AI: MNIST in PyTorch

This notebook walks through a complete beginner workflow using the Hugging Face MNIST dataset and CPU-only PyTorch.

The flow is:

  1. import libraries
  2. load the dataset
  3. preprocess images into tensors
  4. build DataLoaders
  5. define and train a simple fully connected network
  6. save and reload that model
  7. inspect predictions and errors
  8. replace it with a CNN and compare results
In [29]:
# Import the libraries used throughout the notebook
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

1) Load the dataset

load_dataset("mnist") returns a DatasetDict with train and test splits. Each example has:

  • image: a PIL image
  • label: the correct digit (0-9)
In [10]:
# Download/load the MNIST dataset from Hugging Face
dataset = load_dataset("mnist")

2) Preprocess the images

ToTensor() converts each PIL image into a PyTorch tensor with shape [1, 28, 28] and scales pixel values to [0, 1].

A new field called pixel_values is added so the original image field is still available if needed.

In [11]:
to_tensor = ToTensor()

def preprocess(example):
    example["pixel_values"] = to_tensor(example["image"])
    return example

dataset = dataset.map(preprocess)

3) Tell Hugging Face to return PyTorch tensors

After set_format(...), pixel_values and label come back as PyTorch tensors, which makes them work cleanly with a DataLoader.

In [12]:
# Make Hugging Face return torch tensors for these columns
dataset.set_format(type="torch", columns=["pixel_values", "label"])

4) Build the DataLoaders

  • train_loader shuffles training data
  • test_loader keeps test order stable

The batch size controls how many images the model sees before each optimizer step.
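For concreteness, assuming the standard MNIST split of 60,000 training images, the batch size determines how many optimizer steps make up one epoch:

```python
import math

train_size = 60_000  # standard MNIST training split (assumption)
batch_size = 64

# Number of optimizer steps per epoch; the last batch may be smaller
steps_per_epoch = math.ceil(train_size / batch_size)
print(steps_per_epoch)  # 938
```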

In [16]:
# Build DataLoaders for training and evaluation
train_loader = torch.utils.data.DataLoader(
    dataset["train"],
    batch_size=64,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset["test"],
    batch_size=1000,
    shuffle=False
)

5) Sanity-check one batch

This confirms that batching worked and that the shapes are what the model expects:

  • images: [batch, channels, height, width]
  • labels: [batch]
In [17]:
batch = next(iter(train_loader))
print(batch["pixel_values"].shape)
print(batch["label"].shape)
torch.Size([64, 1, 28, 28])
torch.Size([64])

6) Define the first model: a fully connected network

This model flattens each 28x28 image into a length-784 vector, then runs it through two linear layers.

This is a good starter model because it is simple, but flattening discards the 2D layout of the pixels, so it cannot exploit the spatial structure of the image.

In [18]:
# Define a simple fully connected baseline model
device = torch.device("cpu")

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), 28 * 28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

7) Check the model output shape

A classifier for MNIST should return one score per digit class, so the expected shape is [batch_size, 10].

In [19]:
batch = next(iter(train_loader))
output = model(batch["pixel_values"].to(device))
print(output.shape)
torch.Size([64, 10])

8) Try one training step

Before doing a full epoch, it is useful to prove that:

  • the forward pass works
  • the loss computes
  • backpropagation works
  • the optimizer can update the model
In [20]:
batch = next(iter(train_loader))

data = batch["pixel_values"].to(device)
target = batch["label"].to(device)

optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)

print("loss:", loss.item())

loss.backward()
optimizer.step()
loss: 2.3129141330718994

9) Train the fully connected model for one epoch

This runs through the full training set once and reports the average training loss.

In [21]:
model.train()
running_loss = 0.0

for batch in train_loader:
    data = batch["pixel_values"].to(device)
    target = batch["label"].to(device)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

print("average loss:", running_loss / len(train_loader))
average loss: 0.3460213233675085

10) Evaluate the fully connected model

Evaluation turns off gradient tracking and measures classification accuracy on the test set.
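What torch.no_grad() actually does can be seen on a toy tensor: operations inside the block build no autograd graph, which saves memory and time during evaluation. A minimal sketch:

```python
import torch

x = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = x * 2  # no graph is recorded for this op

print(y.requires_grad)  # False: y is detached from autograd
```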

In [22]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        data = batch["pixel_values"].to(device)
        target = batch["label"].to(device)

        output = model(data)
        pred = output.argmax(dim=1)

        correct += (pred == target).sum().item()
        total += target.size(0)

print("accuracy:", correct / total)
accuracy: 0.9443

11) Continue training for a few more epochs

This shows the typical pattern: training loss keeps falling while test accuracy climbs.

In [23]:
for epoch in range(3):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        data = batch["pixel_values"].to(device)
        target = batch["label"].to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"epoch {epoch + 1} loss: {running_loss / len(train_loader):.4f}")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            data = batch["pixel_values"].to(device)
            target = batch["label"].to(device)

            output = model(data)
            pred = output.argmax(dim=1)

            correct += (pred == target).sum().item()
            total += target.size(0)

    print(f"epoch {epoch + 1} accuracy: {correct / total:.4f}")
epoch 1 loss: 0.1609
epoch 1 accuracy: 0.9611
epoch 2 loss: 0.1119
epoch 2 accuracy: 0.9677
epoch 3 loss: 0.0837
epoch 3 accuracy: 0.9703

12) Save the fully connected model

state_dict() returns the learned weights as a dictionary of tensors; torch.save writes them to disk so the model can be reloaded later.
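To see what actually gets saved, state_dict() can be inspected on a fresh instance of the same architecture (a standalone copy of the notebook's two-layer network, named FcNet here to avoid clashing with the notebook's own Net):

```python
import torch.nn as nn

# Standalone copy of the notebook's fully connected architecture
class FcNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

sd = FcNet().state_dict()
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
# fc1.weight (128, 784)
# fc1.bias (128,)
# fc2.weight (10, 128)
# fc2.bias (10,)
```

Only these tensors are written to disk; the class definition itself is not saved, which is why a fresh Net() must be constructed before load_state_dict.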

In [24]:
torch.save(model.state_dict(), "mnist_fc.pt")

13) Reload the saved fully connected model

A fresh model instance is created, then the saved weights are loaded into it.

In [25]:
loaded_model = Net().to(device)
loaded_model.load_state_dict(torch.load("mnist_fc.pt", map_location=device))
loaded_model.eval()
Out[25]:
Net(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

14) Verify the reloaded model

This should match the earlier test accuracy if save/load worked correctly.

In [26]:
loaded_model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        data = batch["pixel_values"].to(device)
        target = batch["label"].to(device)

        output = loaded_model(data)
        pred = output.argmax(dim=1)

        correct += (pred == target).sum().item()
        total += target.size(0)

print("loaded model accuracy:", correct / total)
loaded model accuracy: 0.9703

15) Make a single prediction

This is the simplest inference example: run one test image through the model and compare prediction vs. label.

In [27]:
sample = dataset["test"][0]
image = sample["pixel_values"].unsqueeze(0).to(device)
label = sample["label"].item()

loaded_model.eval()
with torch.no_grad():
    output = loaded_model(image)
    pred = output.argmax(dim=1).item()

print("predicted:", pred)
print("actual:", label)
predicted: 7
actual: 7

16) Display that sample

Showing the image makes it easier to connect model output to what the digit actually looks like.

In [30]:
sample = dataset["test"][0]
image = sample["pixel_values"].squeeze().numpy()
label = sample["label"].item()

loaded_model.eval()
with torch.no_grad():
    output = loaded_model(sample["pixel_values"].unsqueeze(0).to(device))
    pred = output.argmax(dim=1).item()

plt.imshow(image, cmap="gray")
plt.title(f"predicted: {pred}, actual: {label}")
plt.axis("off")
plt.show()

17) Show several predictions at once

A small prediction grid is a fast way to spot correct predictions and obvious mistakes.

In [31]:
loaded_model.eval()

fig, axes = plt.subplots(2, 5, figsize=(10, 4))

for i, ax in enumerate(axes.flat):
    sample = dataset["test"][i]
    image = sample["pixel_values"]
    label = sample["label"].item()

    with torch.no_grad():
        output = loaded_model(image.unsqueeze(0).to(device))
        pred = output.argmax(dim=1).item()

    ax.imshow(image.squeeze().numpy(), cmap="gray")
    ax.set_title(f"p:{pred} a:{label}")
    ax.axis("off")

plt.tight_layout()
plt.show()

18) Find the first misclassified example

Instead of only looking at successful predictions, this cell searches the test set until it finds a mistake.

In [32]:
loaded_model.eval()

for i in range(len(dataset["test"])):
    sample = dataset["test"][i]
    image = sample["pixel_values"]
    label = sample["label"].item()

    with torch.no_grad():
        output = loaded_model(image.unsqueeze(0).to(device))
        pred = output.argmax(dim=1).item()

    if pred != label:
        print("index:", i, "pred:", pred, "actual:", label)

        plt.imshow(image.squeeze().numpy(), cmap="gray")
        plt.title(f"predicted: {pred}, actual: {label}")
        plt.axis("off")
        plt.show()

        break
index: 8 pred: 2 actual: 5

19) Inspect the model's top guesses for that mistake

Looking at the top probabilities helps distinguish:

  • confidently wrong predictions
  • uncertain predictions where several classes look plausible
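The softmax/topk mechanics can be tried on hand-made logits first (the values below are made up purely for illustration):

```python
import torch

# Made-up logits for a single example with 10 classes
logits = torch.tensor([[1.0, 0.2, 3.0, 0.1, 0.0, 2.5, 0.3, 0.0, 0.1, 0.4]])
probs = torch.softmax(logits, dim=1).squeeze()

top_probs, top_indices = torch.topk(probs, 3)
print(top_indices.tolist())  # [2, 5, 0] -- the three largest logits win
```

softmax preserves the ranking of the logits but turns them into probabilities that sum to 1, which is what makes "confidence" comparisons meaningful.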
In [33]:
# i still holds the index of the misclassification found in the previous cell
sample = dataset["test"][i]
image = sample["pixel_values"]
label = sample["label"].item()

loaded_model.eval()
with torch.no_grad():
    output = loaded_model(image.unsqueeze(0).to(device))
    probs = torch.softmax(output, dim=1).squeeze()

top_probs, top_indices = torch.topk(probs, 3)

print("actual:", label)
for rank in range(3):
    print(f"choice {rank+1}: digit {top_indices[rank].item()} prob {top_probs[rank].item():.4f}")
actual: 5
choice 1: digit 2 prob 0.5017
choice 2: digit 6 prob 0.2242
choice 3: digit 5 prob 0.2184

20) Find the most confident wrong prediction

This is especially interesting because it reveals where the model is overconfident.

In [34]:
loaded_model.eval()

worst_i = None
worst_pred = None
worst_label = None
worst_conf = -1.0

with torch.no_grad():
    for j in range(len(dataset["test"])):
        sample = dataset["test"][j]
        image = sample["pixel_values"]
        label = sample["label"].item()

        output = loaded_model(image.unsqueeze(0).to(device))
        probs = torch.softmax(output, dim=1).squeeze()
        pred = probs.argmax().item()
        conf = probs[pred].item()

        if pred != label and conf > worst_conf:
            worst_i = j
            worst_pred = pred
            worst_label = label
            worst_conf = conf

print("index:", worst_i)
print("predicted:", worst_pred)
print("actual:", worst_label)
print("confidence:", worst_conf)
index: 6166
predicted: 3
actual: 9
confidence: 0.9987390637397766

21) Display that confidently wrong example

This particular case motivated the switch from a fully connected model to a CNN.

In [35]:
sample = dataset["test"][6166]
image = sample["pixel_values"].squeeze().numpy()

plt.imshow(image, cmap="gray")
plt.title("predicted: 3, actual: 9, conf: 0.9987")
plt.axis("off")
plt.show()

22) Replace the model with a CNN

A convolutional neural network keeps spatial structure and usually works much better on images than a fully connected network.

This architecture uses:

  • two convolution layers
  • max pooling
  • two fully connected layers at the end
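The shape arithmetic (28 -> 26 -> 13 -> 11 -> 5, from 3x3 convolutions without padding followed by 2x2 pooling) can be verified by pushing a dummy batch through the conv/pool layers on their own:

```python
import torch
import torch.nn as nn

conv1 = nn.Conv2d(1, 16, kernel_size=3)   # 28 - 3 + 1 = 26
conv2 = nn.Conv2d(16, 32, kernel_size=3)  # 13 - 3 + 1 = 11
pool = nn.MaxPool2d(2, 2)                 # halves height/width (floor)

x = torch.zeros(1, 1, 28, 28)  # one dummy grayscale image
x = pool(conv1(x))   # -> [1, 16, 13, 13]
x = pool(conv2(x))   # -> [1, 32, 5, 5]
print(x.shape)       # flattened size: 32 * 5 * 5 = 800
```

This is where the 32 * 5 * 5 input size of the first linear layer comes from.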
In [36]:
# Define a convolutional neural network for better image handling
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))   # 28 -> 26
        x = self.pool(x)                # 26 -> 13
        x = torch.relu(self.conv2(x))   # 13 -> 11
        x = self.pool(x)                # 11 -> 5
        x = x.view(x.size(0), 32 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

23) Create the CNN, loss, and optimizer

This resets training for the new model architecture.

In [37]:
# Create a fresh CNN model and its training objects
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

24) Sanity-check the CNN output shape

Just like before, the output should be [batch_size, 10].

In [38]:
batch = next(iter(train_loader))

data = batch["pixel_values"].to(device)
output = model(data)

print(output.shape)
torch.Size([64, 10])

25) Try one CNN training step

Again, it is worth proving the new model can do one clean optimizer step before training for real.

In [39]:
batch = next(iter(train_loader))

data = batch["pixel_values"].to(device)
target = batch["label"].to(device)

optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)

print("loss:", loss.item())

loss.backward()
optimizer.step()
loss: 2.306305170059204

26) Train the CNN for one epoch

This gives a quick first comparison against the fully connected network.

In [40]:
model.train()
running_loss = 0.0

for batch in train_loader:
    data = batch["pixel_values"].to(device)
    target = batch["label"].to(device)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

print("average loss:", running_loss / len(train_loader))
average loss: 0.25828741326816934

27) Evaluate the CNN

At this point, the CNN should already be competitive or better.

In [41]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        data = batch["pixel_values"].to(device)
        target = batch["label"].to(device)

        output = model(data)
        pred = output.argmax(dim=1)

        correct += (pred == target).sum().item()
        total += target.size(0)

print("accuracy:", correct / total)
accuracy: 0.9691

28) Train the CNN for a few more epochs

This is the main comparison point with the earlier fully connected model.

In [42]:
for epoch in range(3):
    model.train()
    running_loss = 0.0

    for batch in train_loader:
        data = batch["pixel_values"].to(device)
        target = batch["label"].to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"epoch {epoch + 1} loss: {running_loss / len(train_loader):.4f}")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            data = batch["pixel_values"].to(device)
            target = batch["label"].to(device)

            output = model(data)
            pred = output.argmax(dim=1)

            correct += (pred == target).sum().item()
            total += target.size(0)

    print(f"epoch {epoch + 1} accuracy: {correct / total:.4f}")
epoch 1 loss: 0.0821
epoch 1 accuracy: 0.9776
epoch 2 loss: 0.0568
epoch 2 accuracy: 0.9858
epoch 3 loss: 0.0452
epoch 3 accuracy: 0.9848

29) Re-test the tricky rotated/slanted 9

The fully connected model was confidently wrong on this sample. The CNN should handle it better.

In [43]:
sample = dataset["test"][6166]
image = sample["pixel_values"].unsqueeze(0).to(device)
label = sample["label"].item()

model.eval()
with torch.no_grad():
    output = model(image)
    probs = torch.softmax(output, dim=1).squeeze()
    pred = probs.argmax().item()

print("predicted:", pred)
print("actual:", label)
print("confidence:", probs[pred].item())

top_probs, top_indices = torch.topk(probs, 3)
for rank in range(3):
    print(f"choice {rank+1}: digit {top_indices[rank].item()} prob {top_probs[rank].item():.4f}")
predicted: 9
actual: 9
confidence: 0.8512046337127686
choice 1: digit 9 prob 0.8512
choice 2: digit 3 prob 0.1480
choice 3: digit 8 prob 0.0003

30) Save the CNN model

This stores the better image model for later reuse.

In [44]:
# Save the trained CNN weights
torch.save(model.state_dict(), "mnist_cnn.pt")

Notes

Main takeaway from this notebook:

  • a fully connected network can learn MNIST reasonably well
  • a CNN usually performs better because it preserves image structure
  • inspecting mistakes is often more informative than only looking at final accuracy