Learning AI: MNIST in PyTorch¶
This notebook walks through a complete beginner workflow using the Hugging Face MNIST dataset and CPU-only PyTorch.
The flow is:
- import libraries
- load the dataset
- preprocess images into tensors
- build DataLoaders
- define and train a simple fully connected network
- save and reload that model
- inspect predictions and errors
- replace it with a CNN and compare results
# Import the libraries used throughout the notebook
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
1) Load the dataset¶
load_dataset("mnist") returns a DatasetDict with train and test splits.
Each example has:
- image: a PIL image
- label: the correct digit (0-9)
# Download/load the MNIST dataset from Hugging Face
dataset = load_dataset("mnist")
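Printing the returned object is a quick way to confirm the splits, their features, and their sizes (60,000 training and 10,000 test examples):
# Peek at the DatasetDict structure
print(dataset)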
2) Preprocess the images¶
ToTensor() converts each PIL image into a PyTorch tensor with shape [1, 28, 28] and scales pixel values to [0, 1].
A new field called pixel_values is added so the original image field is still available if needed.
to_tensor = ToTensor()
def preprocess(example):
example["pixel_values"] = to_tensor(example["image"])
return example
dataset = dataset.map(preprocess)
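As a quick check on the claims above, converting one raw image (the image column is still present after map) should give shape [1, 28, 28] and values in [0, 1]:
# Convert a single raw PIL image and inspect the result
t = to_tensor(dataset["train"][0]["image"])
print(t.shape, t.min().item(), t.max().item())  # expect torch.Size([1, 28, 28]) and values in [0, 1]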
3) Tell Hugging Face to return PyTorch tensors¶
After set_format(...), pixel_values and label come back as PyTorch tensors, which makes them work cleanly with a DataLoader.
# Make Hugging Face return torch tensors for these columns
dataset.set_format(type="torch", columns=["pixel_values", "label"])
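A quick type check confirms the format change took effect:
# Both formatted columns should now come back as torch.Tensor
example = dataset["train"][0]
print(type(example["pixel_values"]), example["pixel_values"].shape)
print(type(example["label"]))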
4) Build the DataLoaders¶
- train_loader shuffles the training data
- test_loader keeps the test order stable

The batch size controls how many images the model sees before each optimizer step; with 60,000 training images and a batch size of 64, one epoch is ceil(60000 / 64) = 938 steps.
# Build the DataLoaders for training and evaluation
train_loader = torch.utils.data.DataLoader(
dataset["train"],
batch_size=64,
shuffle=True
)
test_loader = torch.utils.data.DataLoader(
dataset["test"],
batch_size=1000,
shuffle=False
)
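The loader lengths confirm that arithmetic (DataLoader keeps the final partial batch by default):
print(len(train_loader))  # ceil(60000 / 64) = 938 batches
print(len(test_loader))   # ceil(10000 / 1000) = 10 batches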
5) Sanity-check one batch¶
This confirms that batching worked and that the shapes are what the model expects:
- images: [batch, channels, height, width]
- labels: [batch]
batch = next(iter(train_loader))
print(batch["pixel_values"].shape)
print(batch["label"].shape)
6) Define the first model: a fully connected network¶
This model flattens each 28x28 image into a length-784 vector, then runs it through two linear layers.
It is a good starter model because it is simple, but flattening throws away the 2D spatial structure of the image, which limits how well it handles shifted or slanted digits.
# Define a simple fully connected baseline model
device = torch.device("cpu")  # everything in this notebook runs on the CPU
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(x.size(0), 28 * 28)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
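For a sense of scale, this baseline has 784·128 + 128 + 128·10 + 10 = 101,770 trainable parameters, which can be confirmed directly:
# Count trainable parameters
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)  # 101770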
7) Check the model output shape¶
A classifier for MNIST should return one score per digit class, so the expected shape is [batch_size, 10].
batch = next(iter(train_loader))
output = model(batch["pixel_values"].to(device))
print(output.shape)
8) Try one training step¶
Before doing a full epoch, it is useful to prove that:
- the forward pass works
- the loss computes
- backpropagation works
- the optimizer can update the model
batch = next(iter(train_loader))
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
print("loss:", loss.item())
loss.backward()
optimizer.step()
9) Train the fully connected model for one epoch¶
This runs through the full training set once and reports the average training loss.
model.train()
running_loss = 0.0
for batch in train_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
print("average loss:", running_loss / len(train_loader))
10) Evaluate the fully connected model¶
Evaluation switches the model to eval mode, disables gradient tracking with torch.no_grad(), and measures classification accuracy on the test set.
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
print("accuracy:", correct / total)
11) Continue training for a few more epochs¶
Across these epochs, the training loss should keep falling while test accuracy keeps improving.
for epoch in range(3):
model.train()
running_loss = 0.0
for batch in train_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"epoch {epoch + 1} loss: {running_loss / len(train_loader):.4f}")
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
print(f"epoch {epoch + 1} accuracy: {correct / total:.4f}")
12) Save the fully connected model¶
state_dict() returns the model's learned parameters; torch.save writes them to disk so the model can be reloaded later.
torch.save(model.state_dict(), "mnist_fc.pt")
13) Reload the saved fully connected model¶
A fresh model instance is created, then the saved weights are loaded into it.
loaded_model = Net().to(device)
loaded_model.load_state_dict(torch.load("mnist_fc.pt", map_location=device))
loaded_model.eval()
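Before re-running the full evaluation, a direct comparison of the two state dicts proves the weights survived the round trip:
# Every saved tensor should match the original model exactly
saved = loaded_model.state_dict()
for key, value in model.state_dict().items():
    assert torch.equal(value, saved[key]), key
print("all parameters match")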
14) Verify the reloaded model¶
This should match the earlier test accuracy if save/load worked correctly.
loaded_model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
output = loaded_model(data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
print("loaded model accuracy:", correct / total)
15) Make a single prediction¶
This is the simplest inference example: run one test image through the model and compare prediction vs. label.
sample = dataset["test"][0]
image = sample["pixel_values"].unsqueeze(0).to(device)
label = sample["label"].item()
loaded_model.eval()
with torch.no_grad():
output = loaded_model(image)
pred = output.argmax(dim=1).item()
print("predicted:", pred)
print("actual:", label)
16) Display that sample¶
Showing the image makes it easier to connect model output to what the digit actually looks like.
sample = dataset["test"][0]
image = sample["pixel_values"].squeeze().numpy()
label = sample["label"].item()
loaded_model.eval()
with torch.no_grad():
output = loaded_model(sample["pixel_values"].unsqueeze(0).to(device))
pred = output.argmax(dim=1).item()
plt.imshow(image, cmap="gray")
plt.title(f"predicted: {pred}, actual: {label}")
plt.axis("off")
plt.show()
17) Show several predictions at once¶
A small prediction grid is a fast way to spot correct predictions and obvious mistakes.
loaded_model.eval()
fig, axes = plt.subplots(2, 5, figsize=(10, 4))
for i, ax in enumerate(axes.flat):
sample = dataset["test"][i]
image = sample["pixel_values"]
label = sample["label"].item()
with torch.no_grad():
output = loaded_model(image.unsqueeze(0).to(device))
pred = output.argmax(dim=1).item()
ax.imshow(image.squeeze().numpy(), cmap="gray")
ax.set_title(f"p:{pred} a:{label}")
ax.axis("off")
plt.tight_layout()
plt.show()
18) Find the first misclassified example¶
Instead of only looking at successful predictions, this cell searches the test set until it finds a mistake.
loaded_model.eval()
for i in range(len(dataset["test"])):
sample = dataset["test"][i]
image = sample["pixel_values"]
label = sample["label"].item()
with torch.no_grad():
output = loaded_model(image.unsqueeze(0).to(device))
pred = output.argmax(dim=1).item()
if pred != label:
print("index:", i, "pred:", pred, "actual:", label)
plt.imshow(image.squeeze().numpy(), cmap="gray")
plt.title(f"predicted: {pred}, actual: {label}")
plt.axis("off")
plt.show()
break
19) Inspect the model's top guesses for that mistake¶
Looking at the top probabilities helps distinguish:
- confidently wrong predictions
- uncertain predictions where several classes look plausible
sample = dataset["test"][i]
image = sample["pixel_values"]
label = sample["label"].item()
loaded_model.eval()
with torch.no_grad():
output = loaded_model(image.unsqueeze(0).to(device))
probs = torch.softmax(output, dim=1).squeeze()
top_probs, top_indices = torch.topk(probs, 3)
print("actual:", label)
for rank in range(3):
print(f"choice {rank+1}: digit {top_indices[rank].item()} prob {top_probs[rank].item():.4f}")
20) Find the most confident wrong prediction¶
This is especially interesting because it reveals where the model is overconfident.
loaded_model.eval()
worst_i = None
worst_pred = None
worst_label = None
worst_conf = -1.0
with torch.no_grad():
for j in range(len(dataset["test"])):
sample = dataset["test"][j]
image = sample["pixel_values"]
label = sample["label"].item()
output = loaded_model(image.unsqueeze(0).to(device))
probs = torch.softmax(output, dim=1).squeeze()
pred = probs.argmax().item()
conf = probs[pred].item()
if pred != label and conf > worst_conf:
worst_i = j
worst_pred = pred
worst_label = label
worst_conf = conf
print("index:", worst_i)
print("predicted:", worst_pred)
print("actual:", worst_label)
print("confidence:", worst_conf)
21) Display that confidently wrong example¶
This particular case (index 6166 in the original run: a slanted 9 predicted as a 3 with confidence 0.9987) motivated the switch from a fully connected model to a CNN.
sample = dataset["test"][worst_i]
image = sample["pixel_values"].squeeze().numpy()
plt.imshow(image, cmap="gray")
plt.title(f"predicted: {worst_pred}, actual: {worst_label}, conf: {worst_conf:.4f}")
plt.axis("off")
plt.show()
22) Replace the model with a CNN¶
A convolutional neural network keeps spatial structure and usually works much better on images than a fully connected network.
This architecture uses:
- two convolution layers
- max pooling
- two fully connected layers at the end
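With no padding and stride 1, each 3x3 convolution shrinks the spatial size by 2 (out = in - kernel + 1), and each 2x2 max pool halves it, rounding down. That is where the 32 * 5 * 5 input size of fc1 below comes from:
# Trace the spatial size through the network
size = 28
size = size - 3 + 1   # conv1: 28 -> 26
size = size // 2      # pool:  26 -> 13
size = size - 3 + 1   # conv2: 13 -> 11
size = size // 2      # pool:  11 -> 5
print(size, 32 * size * size)  # 5 800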
# Define a convolutional neural network for better image handling
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 5 * 5, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = torch.relu(self.conv1(x)) # 28 -> 26
x = self.pool(x) # 26 -> 13
x = torch.relu(self.conv2(x)) # 13 -> 11
x = self.pool(x) # 11 -> 5
x = x.view(x.size(0), 32 * 5 * 5)
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
23) Create the CNN, loss, and optimizer¶
This resets training for the new model architecture.
# Create a fresh CNN model and its training objects
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
24) Sanity-check the CNN output shape¶
Just like before, the output should be [batch_size, 10].
batch = next(iter(train_loader))
data = batch["pixel_values"].to(device)
output = model(data)
print(output.shape)
25) Try one CNN training step¶
Again, it is worth proving the new model can do one clean optimizer step before training for real.
batch = next(iter(train_loader))
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
print("loss:", loss.item())
loss.backward()
optimizer.step()
26) Train the CNN for one epoch¶
This gives a quick first comparison against the fully connected network.
model.train()
running_loss = 0.0
for batch in train_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
print("average loss:", running_loss / len(train_loader))
27) Evaluate the CNN¶
At this point, the CNN should already be competitive or better.
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
print("accuracy:", correct / total)
28) Train the CNN for a few more epochs¶
This is the main comparison point with the earlier fully connected model.
for epoch in range(3):
model.train()
running_loss = 0.0
for batch in train_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"epoch {epoch + 1} loss: {running_loss / len(train_loader):.4f}")
model.eval()
correct = 0
total = 0
with torch.no_grad():
for batch in test_loader:
data = batch["pixel_values"].to(device)
target = batch["label"].to(device)
output = model(data)
pred = output.argmax(dim=1)
correct += (pred == target).sum().item()
total += target.size(0)
print(f"epoch {epoch + 1} accuracy: {correct / total:.4f}")
29) Re-test the tricky rotated/slanted 9¶
The fully connected model was confidently wrong on this sample. The CNN should handle it better.
sample = dataset["test"][worst_i]  # the slanted 9 found in step 20 (index 6166 in the original run)
image = sample["pixel_values"].unsqueeze(0).to(device)
label = sample["label"].item()
model.eval()
with torch.no_grad():
output = model(image)
probs = torch.softmax(output, dim=1).squeeze()
pred = probs.argmax().item()
print("predicted:", pred)
print("actual:", label)
print("confidence:", probs[pred].item())
top_probs, top_indices = torch.topk(probs, 3)
for rank in range(3):
print(f"choice {rank+1}: digit {top_indices[rank].item()} prob {top_probs[rank].item():.4f}")
30) Save the CNN model¶
This stores the better image model for later reuse.
# Save the trained CNN weights
torch.save(model.state_dict(), "mnist_cnn.pt")
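Reloading works exactly as it did in step 13; the only caveat is that Net must refer to the CNN class from step 22 when the weights are loaded:
# Sketch: reload the saved CNN
cnn = Net().to(device)
cnn.load_state_dict(torch.load("mnist_cnn.pt", map_location=device))
cnn.eval()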
Notes¶
Main takeaway from this notebook:
- a fully connected network can learn MNIST reasonably well
- a CNN usually performs better because it preserves image structure
- inspecting mistakes is often more informative than only looking at final accuracy