Training a Neural Network With a Few Clean Lines of Code in PyTorch
This article explains how to train a neural network with only a few clean lines of code that are reusable, maintainable and easy to understand.
Less code usually results in readable code that is easy to understand and easy to maintain. Python, which has become very popular in the machine learning community, lets you achieve great results with less code than most other programming languages.
PyTorch is a popular deep learning framework for Python with a clean API that allows you to write code that really feels like Python. That makes it fun to create models and run machine learning experiments with PyTorch.
In this article I’m going to show you the basic steps that you need to train a simple classifier that recognizes handwritten digits. You will see how to
- load the MNIST dataset (the ‘Hello World’ dataset for machine learning) with PyTorch’s data loader
- declare the architecture of our model
- select an optimizer
- implement the training loop
- determine the accuracy of the trained model
I want to keep everything as simple as possible. Therefore, I don’t cover things like overfitting, data preprocessing or different metrics to evaluate the performance of our classifier. We will just implement the basic building blocks that are required for training a classifier and that can be easily reused in other machine learning experiments.
So let’s start writing some code.
The first thing we need to do is import the necessary packages. As we're using PyTorch, we need to import the packages torch and torchvision.
import torch
import torchvision as tv
Loading the Data
Now, we can load our training and validation dataset via torchvision.
t = tv.transforms.ToTensor()

mnist_training = tv.datasets.MNIST(
    root='/tmp/mnist',
    train=True,
    download=True,
    transform=t
)

mnist_val = tv.datasets.MNIST(
    root='/tmp/mnist',
    train=False,
    download=True,
    transform=t
)
First, we create an instance of ToTensor(), which we use to convert the images we get from the datasets package into tensors. We need this step because all PyTorch functions operate on tensors. If you don't know tensors: a tensor is basically just a fancy name for a multidimensional array. A tensor has a rank. For instance, a tensor of rank 0 is a scalar, a tensor of rank 1 is a vector, a tensor of rank 2 is a matrix, and so on.
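To make this concrete, here is a tiny illustration of tensor ranks (my own snippet, not part of the article's training code):
# Illustration (not part of the training code): tensors of different
# ranks; dim() returns the rank of a tensor.
scalar = torch.tensor(3.0)         # rank 0: a scalar
vector = torch.tensor([1.0, 2.0])  # rank 1: a vector
matrix = torch.zeros(2, 3)         # rank 2: a matrix
print(scalar.dim(), vector.dim(), matrix.dim())  # prints: 0 1 2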
Then, we load our training and validation dataset. With root we can specify the directory which is used to store the dataset on disk. If we set train to true, the training set is loaded; otherwise the validation set is loaded. If we set download to true, PyTorch downloads the dataset and stores it in the directory specified via root. Finally, we can specify the transformation that should be applied to each example of the training and validation dataset. In our case it's just ToTensor().
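As a quick sanity check (my own snippet, not from the original code), we can look at a single example of the training dataset. Thanks to ToTensor(), each image is already a tensor of shape (1, 28, 28):
# Inspect one sample (assumed snippet): the transform has already
# converted the image to a tensor; the label is a plain integer.
img, label = mnist_training[0]
print(img.shape)  # torch.Size([1, 28, 28])
print(label)      # an integer between 0 and 9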
Specify the Architecture of Our Model
Next, we specify the architecture of our model.
model = torch.nn.Sequential(
    torch.nn.Linear(28*28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10)
)
Here, we use a very simple neural network architecture: an input layer with 784 (28*28) neurons, a hidden layer with 128 neurons, and an output layer with 10 neurons (one neuron for each possible label). We're using ReLU as the activation function.
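To double-check the architecture, we can feed a dummy batch of flattened images through the model (a quick sketch of my own, not part of the original article):
# Shape check (assumed snippet): 4 flattened dummy images go in,
# one score per class comes out for each image.
dummy = torch.zeros(4, 28*28)
print(model(dummy).shape)  # torch.Size([4, 10])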
Selecting an Optimizer and Loss Function
Next, we specify the optimizer and the loss function.
opt = torch.optim.Adam(params=model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
We are using the Adam optimizer. With the first parameter we specify the parameters of our model that the optimizer needs to optimize. With the second parameter lr we specify the learning rate.
In the second line we select the CrossEntropyLoss as the loss function (another common name for the loss function is criterion). This function takes the unnormalized (N x 10)-dimensional output of our output layer (N is the number of samples in our batch) and computes the loss between the network's output and the target labels. The target labels are represented as an N-dimensional vector (or more concretely, a tensor of rank 1) that contains the class indices of the input samples. As you can see, CrossEntropyLoss is a very convenient function. First, we don't need a normalization layer such as softmax at the end of our network. Second, we don't have to convert between different representations of labels. Our network outputs a 10-dimensional vector of scores and the target labels are provided as a vector of class indices (an integer between 0 and 9).
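Here is a tiny illustration of that convenience, with made-up numbers of my own:
# Illustration (made-up numbers): CrossEntropyLoss accepts raw,
# unnormalized scores and plain class indices, no softmax or
# one-hot encoding required.
scores = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 3.0, 0.3]])  # batch of 2 samples, 3 classes
targets = torch.tensor([0, 1])            # correct class index per sample
print(torch.nn.CrossEntropyLoss()(scores, targets))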
Next we create a data loader for the training dataset.
loader = torch.utils.data.DataLoader(
    mnist_training,
    batch_size=500,
    shuffle=True
)
The data loader is used to retrieve samples from a dataset. We can use a data loader to easily iterate over batches of samples. Here, we create a loader that returns 500 samples from the training dataset in each iteration. If we set shuffle to true, the samples are shuffled before they are divided into batches, so each epoch sees them in a different random order.
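We can peek at a single batch to see what the loader returns (my own check, not part of the original code):
# Peek at one batch (assumed snippet): 500 images and their labels.
imgs, labels = next(iter(loader))
print(imgs.shape)    # torch.Size([500, 1, 28, 28])
print(labels.shape)  # torch.Size([500])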
Training a Machine Learning Model
Now, we have everything we need to train our model.
for epoch in range(10):
    for imgs, labels in loader:
        n = len(imgs)
        imgs = imgs.view(n, -1)
        predictions = model(imgs)
        loss = loss_fn(predictions, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    print(f"Epoch: {epoch}, Loss: {float(loss)}")
We use 10 epochs to train our network (line 1). In each epoch we iterate over the loader to get 500 images with their labels in each iteration (line 2). The variable imgs is a tensor of shape (500, 1, 28, 28). The variable labels is a tensor of rank 1 with 500 class indices.
In line 3 we save the number of images of the current batch in the variable n. In line 4 we reshape the imgs tensor from the shape (n, 1, 28, 28) into a tensor of shape (n, 784). In line 5 we use our model to predict the labels of all images of the current batch. Then, in line 6, we compute the loss between these predictions and the ground truths. The tensor predictions has the shape (n, 10) and labels is a tensor of rank 1 that contains the class indices. In lines 7 to 9 we reset the gradients of all parameters of our network, compute the gradient, and update the model's parameters.
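The reshaping in line 4 is worth a closer look. Here is a small standalone illustration (my own example) of what view(n, -1) does:
# Illustration (assumed example): view(n, -1) flattens each 1x28x28
# image into a row of 784 values; -1 tells PyTorch to infer that size.
x = torch.zeros(500, 1, 28, 28)
print(x.view(len(x), -1).shape)  # torch.Size([500, 784])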
We also print the loss after each epoch so that we can verify that the network gets better (i.e. the loss decreases) after each epoch.
Determine the Accuracy
Now that we have trained the network, we can determine the accuracy with which the model recognizes handwritten digits.
First, we need to get the data from our validation dataset.
n = 10000
loader = torch.utils.data.DataLoader(mnist_val, batch_size=n)
images, labels = next(iter(loader))
Our validation dataset mnist_val contains 10000 images. To get all these images, we use a DataLoader and set the batch_size to 10000. Then, we can get the data by creating an iterator from the data loader and passing it to next(), which returns the first (and only) batch.
The result is a tuple. The first element of this tuple is a tensor of shape (10000, 1, 28, 28). The second element is a tensor of rank 1 which contains the class indices for the images.
Now, we can use our model to predict the labels of all images.
predictions = model(images.view(n, -1))
Before we can provide the data to our model we need to reshape it (similar to what we’ve already done in the training loop). The input tensor for our model needs to have the shape (n, 784). When the data has the correct shape, we can use it as input for our model.
The result is a tensor of the shape (10000, 10). For each image of our validation set, this tensor stores the scores for each of the ten possible labels. The higher the score for a label, the more likely it is that the sample belongs to the corresponding label.
Usually, a sample is assigned to the label that has the highest score. We can easily determine that label via the tensor's argmax() method, which returns the position of the maximum.
predicted_labels = predictions.argmax(dim=1)
We are interested in the maximum for dimension 1 as the scores are stored along this dimension for each sample. The result is a tensor of rank 1 which now stores the predicted class indices instead of the scores.
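A small illustration with made-up scores (my own example) shows how argmax() works:
# Illustration (made-up scores): argmax along dim=1 returns the index
# of the highest score in each row, i.e. the predicted class.
scores = torch.tensor([[0.1, 2.0, 0.3],
                       [1.5, 0.2, 0.1]])
print(scores.argmax(dim=1))  # tensor([1, 0])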
Now, we can compare the predicted class indices with the ground truth labels to compute the accuracy for the validation dataset. The accuracy is defined as the fraction of samples that have been predicted correctly, i.e. the number of correctly predicted samples divided by the total number of samples.
The trick to get this metric is to perform an elementwise comparison of the predicted labels and the ground truth labels.
torch.sum(predicted_labels == labels) / n
We compare predicted_labels with labels for equality and get a vector of booleans. An element in this vector is true if the two elements at the same position are equal; otherwise, it is false. Then, we use sum to count the number of elements that are true and divide that number by n.
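A tiny worked example (my own numbers) makes the trick clear:
# Worked example (made-up numbers): compare elementwise, count the
# True values with sum, and divide by the number of samples.
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([1, 0, 3, 4])
print(a == b)                 # tensor([ True, False,  True,  True])
print(torch.sum(a == b) / 4)  # tensor(0.7500)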
If we perform all these steps, we should achieve an accuracy of roughly 97%.
The full code is also available on GitHub.
Conclusion
As a software engineer, I love clean code. Sometimes I sit in front of a small piece of code for half an hour or even longer just to make it more elegant and beautiful.
That's why I love PyTorch. PyTorch is a great deep learning framework that allows you to write clean code that is easy to understand and feels like Python. We've seen that with PyTorch it's possible to write expressive code that creates and trains state-of-the-art machine learning models with just a few lines of code.
The building blocks we have used in this simple demo can be reused in many machine learning projects and I hope they will also be helpful for you during your machine learning journey.