Hello folks!

Recently I trained a handwritten digit classifier on the MNIST dataset from scratch, and it was an eye-opening experience. What used to look like magic to me turned out to be just a few mathematical concepts applied together. I was excited by how simple the whole process was, and I want to share it with everyone who still wonders what kind of magic is happening inside.

All you need is a bit of Python and a few concepts from high-school math. No prior machine learning background is necessary.

This post is inspired by chapter 4 of the fast.ai book (I highly recommend it if you’re getting started with deep learning), which builds a classifier that recognizes 3’s and 7’s from the MNIST dataset. This time, however, we will train a classifier to recognize all ten digits of the dataset. I will also try to avoid “magical” high-level components as much as possible.

Let’s go!

We will use PyTorch and fast.ai libraries for a few useful utilities.

First, we need to import a few functions from the fastai library:

from fastai.vision.all import *
from fastbook import *

matplotlib.rc('image', cmap='Greys')

Downloading the dataset

Let’s download the dataset first. Fast.ai has a convenient function to quickly obtain the dataset we’re going to use for training:

path = untar_data(URLs.MNIST)
(#2) [Path('/home/nm/.fastai/data/mnist_png/training'),Path('/home/nm/.fastai/data/mnist_png/testing')]

The dataset has two folders inside it, training and testing. training should be used for training and validating the model, and testing should be used to compare the accuracy of different models.

A small note: I will take a shortcut here and use the testing dataset to validate my model. That is something you shouldn’t do in production when you want to evaluate the model’s performance. Instead, you should split the training set into training and validation parts. For our toy problem, however, and for the sake of simplicity, I will use the testing set for validation.

Let’s peek inside the training folder:

(path/'training').ls()
(#10) [...]

What we have here are ten folders, one for each digit. Every folder contains thousands of 28x28 images of that digit.

You might say: “Well, that’s just a bunch of images compressed with a PNG algorithm. How can we do any math operations on them?”. You’re correct, plain image files are not very useful. To get started, we will first transform the images into a tensor.

“Wait, what is a tensor?” you ask, and I’d say it’s a fancy name for a multi-dimensional array. The number of dimensions of this array is called the rank of the tensor.

For example:

A zero-dimensional array (also called a scalar), such as 1, is a tensor of rank 0.

A 1-d array, also known as a vector, such as [1,2,3], is a tensor of rank 1.

A 2-d array, also known as a matrix, such as [[1,2],[3,4]], is a tensor of rank 2, and so on.
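In PyTorch, the rank of a tensor is available as its ndim attribute, and shape gives the size along each dimension. A minimal sketch with small example values:

```python
import torch

scalar = torch.tensor(1)                 # rank 0: a single number
vector = torch.tensor([1, 2, 3])         # rank 1: a 1-d array
matrix = torch.tensor([[1, 2], [3, 4]])  # rank 2: a 2-d array

for name, t in [("scalar", scalar), ("vector", vector), ("matrix", matrix)]:
    # ndim is the rank; shape lists the size along each dimension
    print(name, t.ndim, tuple(t.shape))
# scalar 0 ()
# vector 1 (3,)
# matrix 2 (2, 2)
```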

We want to convert the images to tensors because then we can perform math operations on them.

Each image is a 2-d array of pixels with intensity values ranging from 0 to 255. We will represent it as a tensor of rank 2 and scale the pixel values to the range between 0 and 1, which gives us a convenient abstraction for some of the operations we will do next.

Once we have created a tensor from one image, we will convert the other images in our dataset to tensors as well. Then we will stack the image tensors together into a single tensor of rank 3. I’ll explain why we do this in a bit.

Let’s write a function for this:

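def image_path_to_tensor(path):
    images_paths = path.ls()
    # open every image in the folder as a 28x28 tensor of integers in [0, 255]
    tensors = [tensor(Image.open(p)) for p in images_paths]
    # stack into a single [n_images, 28, 28] tensor and scale pixels to [0, 1]
    return torch.stack(tensors).float()/255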

Having defined the function, let’s use it to transform the images in all ten folders:

stacked_tensors = [image_path_to_tensor(image_folder) for image_folder in (path/'training').ls().sorted()]

The result is a list of ten tensors, each holding the training images of one digit. Let’s look at one digit and check the shape of its tensor:

stacked_tensors[4].shape
torch.Size([5842, 28, 28])

The first dimension of the tensor corresponds to an image file; the second and third dimensions are the height and width of the image in pixels.

Because we have 5842 images of the digit 4, and every image is 28x28 pixels, the resulting tensor has the shape 5842 by 28 by 28. That’s a tensor of rank 3.
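To see how stacking produces a rank-3 tensor without downloading anything, here is a sketch that uses randomly generated 28x28 “images” with integer pixel values in place of the decoded PNG files (the random data is purely an assumption for illustration):

```python
import torch

# Hypothetical stand-ins for decoded PNGs: 5 grayscale 28x28 images with
# integer pixel intensities in [0, 255], like the tensors built from MNIST files.
images = [torch.randint(0, 256, (28, 28), dtype=torch.uint8) for _ in range(5)]

# Stack along a new first dimension, convert to float and scale to [0, 1].
stacked = torch.stack(images).float() / 255

print(stacked.shape)  # torch.Size([5, 28, 28])
print(stacked.ndim)   # 3
```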

Building a baseline: pixel difference

Before we start building any machine learning model, it is generally a good idea to come up with a simple baseline first. Sometimes the problem at hand can be solved quite well without any machine learning at all, and a baseline lets us justify using a machine learning model only if the baseline performance is not good enough.

One obvious choice would be a random baseline, where each digit is guessed at random, but let’s try a little harder and build a baseline that relies on pixel similarity.

The idea is the following: we will average the image tensors of each digit along the first axis, and the result will be our “pretty average digit”. Then we will compare each pixel of the image we try to classify with the corresponding pixel of every “average” digit, and our final guess will be the digit with the smallest error.

To start, let’s compute our “average” digit images:

image_means = [None] * 10
for digit, imgs_tensor in enumerate(stacked_tensors):
    mean_tensor = imgs_tensor.mean(0)
    image_means[digit] = mean_tensor

For example, this is what an “average” 4 looks like:


Now we need a function that can compute the error. One possible option is to subtract the pixels of two images, square the differences so that they are all non-negative, and then take the mean of the result:

def calculate_error(image, label):
    ideal_tensor = image_means[label]
    return ((image - ideal_tensor)**2).mean()

Let’s try it out on a digit 3:


The average 3 looks like this:


calculate_error(stacked_tensors[3][0], 3)

This error measure is called the mean squared error (MSE).
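To make the formula concrete, here is a sketch on two made-up 2x2 “images” (the pixel values are invented for illustration): the differences are squared so they cannot be negative, and then averaged.

```python
import torch

# Two tiny made-up images with pixel values already scaled to [0, 1]
image = torch.tensor([[0.0, 1.0],
                      [0.5, 0.5]])
ideal = torch.tensor([[0.0, 0.5],
                      [1.0, 0.5]])

diff = image - ideal        # [[0.0, 0.5], [-0.5, 0.0]]
squared = diff ** 2         # [[0.0, 0.25], [0.25, 0.0]]
mse = squared.mean()        # (0 + 0.25 + 0.25 + 0) / 4 = 0.125
print(mse.item())           # 0.125
```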

Now let’s create a validation set. As I already mentioned above, using the testing set as a validation set is a shortcut for simplicity, and you should not do the same in production.

valid_stacked = [image_path_to_tensor(image_folder) for image_folder in (path/'testing').ls().sorted()]

We’re able to compute the error for a single image. Our job now is to compute the error for all images in the validation set. We could simply write a loop to do this, but it would be painfully slow.


The main issue is that Python loops are not vectorized, so they cannot take advantage of the GPU. GPU stands for Graphics Processing Unit, and you most likely have one in your computer. GPUs are important for us because they can perform math operations with a crazy level of parallelism. But if we take a single image, put it on the GPU, calculate the error and then remove it from the GPU, it will take us quite a while to calculate the error for the whole validation set.

Instead of using a loop and processing images one by one, we will grab a tensor with images, put it on the GPU, and then calculate the error for all images in parallel. GPUs are super fast number-crunching machines, and it will take our GPU very little time to do that.

But how can we do this? Welcome to broadcasting. Broadcasting is a technique that allows PyTorch to perform operations on tensors with different shapes as if they were tensors of the same shape. Let me explain it with a simple example.

If we have two tensors:

a = [[1,2],
     [3,4]]

b = [1,1]

then if we write something like c = a + b, we will get the following tensor as a result:

c = [[2,3],
     [4,5]]

Tensor b was added to each row of tensor a even though they have different ranks. Simply speaking, instead of looping through the rows of the matrix a, we told PyTorch to do this in a declarative fashion.
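The example can be reproduced directly in PyTorch. This sketch assumes the second row of a is [3, 4] (any values would do) and checks that broadcasting gives the same result as an explicit loop over the rows:

```python
import torch

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([1, 1])

# b (shape [2]) is broadcast across every row of a (shape [2, 2])
c = a + b
print(c)
# tensor([[2, 3],
#         [4, 5]])

# The same result written as an explicit Python loop over rows
c_loop = torch.stack([row + b for row in a])
print(torch.equal(c, c_loop))  # True
```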

With broadcasting in mind, let’s define a few functions to measure the distances between our images.

A naïve approach of subtracting one image’s pixels from another’s and taking the mean will not work very well. Two drastically different images can produce positive and negative pixel differences that cancel each other out, so the mean ends up close to zero. To solve this problem, we should make every difference positive before taking the mean.

For example, we can calculate the mean absolute error (MAE) or the mean squared error (MSE). The difference between them is that MSE penalizes large pixel differences more heavily, so it results in a higher error when images are very different.

def mnist_distance_mae(a,b): return (a-b).abs().mean((-1,-2))
def mnist_distance_mse(a,b): return torch.square((a-b)).mean((-1,-2))

The functions above subtract the pixels of two images, take the absolute or squared value, and compute the mean. (-1,-2) means that we compute the mean along the last two axes of our tensor, which correspond to the 28x28 image.
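Here is a sketch of how these two functions behave with broadcasting. Random tensors stand in for the validation images and the “average” digit (the 1010x28x28 shape mirrors the digit-3 validation set, but the values are made up), and two tiny hand-made error patterns show MSE weighting one big difference more than several small ones:

```python
import torch

def mnist_distance_mae(a, b): return (a - b).abs().mean((-1, -2))
def mnist_distance_mse(a, b): return torch.square(a - b).mean((-1, -2))

batch = torch.rand(1010, 28, 28)   # stand-in for a stack of validation images
mean_image = torch.rand(28, 28)    # stand-in for one "average" digit

# mean_image is broadcast against every image in the batch, and the mean
# over the last two axes leaves one distance per image.
mae = mnist_distance_mae(batch, mean_image)
mse = mnist_distance_mse(batch, mean_image)
print(mae.shape, mse.shape)  # torch.Size([1010]) torch.Size([1010])

# Same MAE, different MSE: four small differences vs. one large one
x = torch.zeros(2, 2)
spread = torch.full((2, 2), 0.25)
peak = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
print(mnist_distance_mae(x, spread), mnist_distance_mae(x, peak))  # both 0.25
print(mnist_distance_mse(x, spread), mnist_distance_mse(x, peak))  # 0.0625 vs 0.25
```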

When I was implementing this code, broadcasting led to some very nasty bugs. For example, in one case my model was not able to make any progress in training. Another time, training was much slower than I expected, and I did not know why. As it turned out, implicit broadcasting had introduced a bug into the loss function, and it took me a lot of time to localize the problem. It is very important to check that the shape of the returned tensor at every stage matches what we expect.
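One common way implicit broadcasting goes wrong (a general illustration, not necessarily the exact bug described above) is subtracting a tensor of shape [N] from one of shape [N, 1]: instead of an element-wise difference, PyTorch silently produces an [N, N] matrix, and any loss computed from it is wrong without raising an error. A cheap shape assertion catches it:

```python
import torch

preds = torch.rand(4, 1)    # e.g. a model output with a trailing dimension of size 1
targets = torch.rand(4)     # one target per example

diff = preds - targets      # broadcasts to shape [4, 4], not [4]!
print(diff.shape)           # torch.Size([4, 4])

fixed = preds.squeeze(1) - targets   # drop the extra dimension first
assert fixed.shape == targets.shape  # check that the shapes match what we expect
print(fixed.shape)          # torch.Size([4])
```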

Let’s check the shape of the result of our distance function:

mnist_distance_mae(valid_stacked[3], image_means[3]).shape
torch.Size([1010])

It is a tensor of size 1010, with one distance per validation image. That seems to be correct, and we can continue.

mnist_distance_mse(valid_stacked[3], image_means[3])
tensor([0.0575, 0.0516, 0.0542,  ..., 0.0480, 0.0642, 0.0429])

Given the distance functions, now we can write a function which will tell us if our digit prediction is correct:

def analyze_digit(digit: int, candidate, distance_function):
    # distances from the candidate image(s) to each of the ten "average" digits
    distances = torch.stack([distance_function(candidate, mean_tensor) for mean_tensor in image_means])
    # the index of the closest "average" digit is our prediction
    predicted = torch.argmin(distances, dim=0)
    return predicted == digit
analyze_digit(4, stacked_tensors[3][113], mnist_distance_mae)

Let’s put broadcasting to work and analyze all 3s at once:

analyze_digit(3, valid_stacked[3], mnist_distance_mae).float().mean()

Our accuracy is about 60%. That is better than random (which would be around 10%), but we can do better. What is the mean accuracy over all digits? We can use both of our distance functions and compare the results:

accuracy_mae = torch.stack([analyze_digit(digit, valid_stacked[digit], mnist_distance_mae).float().mean() for digit in range(10)])
tensor([0.8153, 0.9982, 0.4234, 0.6089, 0.6680, 0.3262, 0.7871, 0.7646, 0.4425, 0.7760])
accuracy_mse = torch.stack([analyze_digit(digit, valid_stacked[digit], mnist_distance_mse).float().mean() for digit in range(10)])
tensor([0.8959, 0.9621, 0.7568, 0.8059, 0.8259, 0.6861, 0.8633, 0.8327, 0.7372, 0.8067])

As we can see, MSE gives better results, since it penalizes bigger differences more than smaller ones. Let’s look at the mean accuracies across the ten digits:

accuracy_mae.mean(), accuracy_mse.mean()

Averaged over the ten digits, the per-digit accuracies above come to roughly 66% for MAE and 82% for MSE.

We got 82% accuracy without any learning at all! We can take it as our baseline, and our goal would be to train a machine learning model which can do this better. It is always a good idea to start with a simple solution and then move to a more sophisticated one if the observed performance is below expectations.
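For reference, the whole baseline can be expressed as one vectorized function that classifies a batch of images in a single call. This is a sketch: predict_nearest_mean is a hypothetical helper, means is assumed to be the ten “average” digits stacked into a [10, 28, 28] tensor, and the random data below only checks the shapes and that each average digit is classified as itself:

```python
import torch

def predict_nearest_mean(images, means):
    # images: [N, 28, 28], means: [10, 28, 28]
    # Insert singleton dims so broadcasting compares every image with every mean:
    # [N, 1, 28, 28] - [1, 10, 28, 28] -> [N, 10, 28, 28]
    diff = images.unsqueeze(1) - means.unsqueeze(0)
    dists = (diff ** 2).mean((-1, -2))   # [N, 10] mean squared errors
    return dists.argmin(dim=1)           # index of the closest "average" digit

means = torch.rand(10, 28, 28)           # hypothetical stand-in for the ten averages
preds = predict_nearest_mean(means, means)
print(preds)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

batch = torch.rand(7, 28, 28)            # a made-up batch of 7 images
print(predict_nearest_mean(batch, means).shape)  # torch.Size([7])
```

With the tensors from this post, means would be torch.stack(image_means).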

Machine Learning time!

Now, we’re finally going to do some machine learning.


First, we will concatenate the images of all digits into a single tensor, and then we will represent every 28x28 image as a vector of 784 values. I’ll explain why we do this in a bit. This tensor is going to be the “input” for our model.

train_x = torch.cat([stacked_tensors[i] for i in range(10)]).view(-1, 28*28)

Next, let’s create another tensor containing the correct labels for every image in our previous tensor.

train_y = torch.cat([torch.stack([tensor(i)]*len(stacked_tensors[i])) for i in range(10)])
train_x.shape, train_y.shape
(torch.Size([60000, 784]), torch.Size([60000]))

Finally, we will combine them into a dataset. A dataset is just a collection of (input, output) pairs for the model; in our case, the inputs are the images (train_x) and the outputs are the labels (train_y).

dataset = list(zip(train_x,train_y))

We can also look at what is inside. To do this, we need to convert a 784-element vector back into a 28x28 tensor, and the view method of a tensor can help us with it:
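A minimal sketch of going back and forth between the flat 784-value vector and the 28x28 image with view (random pixel values stand in for a real image):

```python
import torch

image = torch.rand(28, 28)       # stand-in for one MNIST image
flat = image.view(-1)            # shape [784]: rows laid out one after another
restored = flat.view(28, 28)     # back to shape [28, 28]

print(flat.shape, restored.shape)    # torch.Size([784]) torch.Size([28, 28])
print(torch.equal(image, restored))  # True: view only changes the shape, not the data
# Pixel (row, col) of the image lives at index row * 28 + col of the flat vector
print(torch.equal(flat[1 * 28 + 5], image[1, 5]))  # True
```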



And the label is also 5:


Now we perform the same operations for the validation set:

valid_x = torch.cat([valid_stacked[i] for i in range(10)]).view(-1, 28*28)
valid_y = torch.cat([torch.stack([tensor(i)]*len(valid_stacked[i])) for i in range(10)])
valid_x.shape, valid_y.shape
(torch.Size([10000, 784]), torch.Size([10000]))
valid_dset = list(zip(valid_x,valid_y))

Linear model

Before we continue, let’s quickly talk about what we have done and what we’re going to do next.

We have already collected the training and the validation datasets, and the dataset is just a collection of images and the corresponding labels. We also transformed every image into a vector.

What we need now is to define a function, which will take an image as an input and will produce the prediction as an output.

To be more specific, our function will accept a vector of size 784 as an input, where each element of this vector corresponds to a pixel of an image. Since our function does not care about the arrangement of the input values, we reshaped our tensor earlier to make it more convenient to work with. The result will be a vector of size 10, where each number represents how likely the image is to be the corresponding digit.

In the code, it would be something like this:

    from typing import List
    def predict(pixels: List[float]) -> List[float]: ...

But what kind of function is capable of doing such a transformation? To our luck, such functions exist, and they are called neural networks!

It has been proven that, given enough parameters, a neural network with only one hidden layer (and a non-linear activation) can approximate any continuous function on a bounded input domain to an arbitrary level of precision. How cool is that! For more information, look up the universal approximation theorem.

We, however, will start with a degenerate case of a neural network called a linear function.

In pseudocode it looks like this:

def linear(x: float, weight: float, bias: float) -> float:
    return x * weight + bias

This function accepts an input number x, a parameter called weight, and a parameter called bias. The input is multiplied by the weight, and the bias is added at the end. Sounds simple, right?

However, in our case we have not a single input value but 784 of them. For now, let’s pretend that we’re only interested in predicting a single digit:

from typing import List
import torch
def linear(x: torch.Tensor, weights: torch.Tensor, bias: float) -> torch.Tensor:
    # weighted sum of all 784 pixels plus the bias
    return (x * weights).sum() + bias

Here * stands for element-wise vector multiplication, meaning that if we have two vectors,

a = [1, 2, 3]
b = [1, 2, 3]

then a * b would be the vector [1*1, 2*2, 3*3]
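In PyTorch, * on two tensors of the same shape is exactly this element-wise product, so the single-output linear function works when x and weights are tensors. A small check with made-up values:

```python
import torch

a = torch.tensor([1., 2., 3.])
b = torch.tensor([1., 2., 3.])
print(a * b)  # tensor([1., 4., 9.])

def linear(x, weights, bias):
    # element-wise product, summed, plus the bias: a weighted sum of the inputs
    return (x * weights).sum() + bias

print(linear(a, b, 0.5))  # 1 + 4 + 9 + 0.5 = 14.5
```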

The result of this function will be a score of how likely it is that our input x corresponds to the label. But what if we want to predict more than one label?

That’s simple as well. First, we define several sets of weights and biases, one per class we want to predict. Then we call the function once per class. The number returned for each class represents the model’s confidence that the image shows that digit, and we interpret the class with the highest number as the prediction. Let’s say we only want to predict three digits; then we would do the following:

score_0 = linear(image, weights_0, bias_0)
score_1 = linear(image, weights_1, bias_1)
score_2 = linear(image, weights_2, bias_2)

Finally, we check which resulted in the highest score, and this is going to be the predicted class.

Let’s wrap it in a function:

def predict(image: torch.Tensor, weights: List[torch.Tensor], biases: List[float]) -> List[torch.Tensor]:
    results = []
    for digit in range(3):
        likeliness = linear(image, weights[digit], biases[digit])
        results.append(likeliness)  # one score per digit
    return results

That’s it! Our simple linear model can predict a digit from an image’s pixels.

However, we have two problems:

  1. This function is going to be extremely slow, since GPUs do not like loops, and the loop we defined above will be executed by the plain Python interpreter.
  2. We don’t know where to get these magical weights and biases!

To deal with the first problem, we will utilize PyTorch tensors, which will give us free GPU parallelization.
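To see why tensors solve the first problem, here is a sketch comparing the per-digit loop with a single matrix-vector product. The random image, weights and biases are stand-ins; weights is stored as 784 rows by 10 columns, so column d holds the weights for digit d:

```python
import torch

torch.manual_seed(0)
image = torch.rand(784)            # one flattened 28x28 image
weights = torch.randn(784, 10)     # column d = weights for digit d
biases = torch.randn(10)

# Loop version: one weighted sum per digit
loop_scores = torch.stack([(image * weights[:, d]).sum() + biases[d] for d in range(10)])

# Vectorized version: a single matrix-vector product computes all ten sums at once
vec_scores = image @ weights + biases

print(torch.allclose(loop_scores, vec_scores, atol=1e-4))  # True
print(vec_scores.argmax().item())  # index of the highest score = predicted digit
```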

As for the second problem, we need to learn about Stochastic Gradient Descent (SGD).

Everything is a Tensor

Since we want to use the GPU for computation, we need to represent our weights and biases as tensors. Let’s write a helper function for this.

def init_params(size): return torch.randn(size).requires_grad_()

The .requires_grad_() method tells PyTorch to track the operations done with the tensor so that later we can ask PyTorch to compute gradients for us. More on that in a bit.

Remember the weights in the predict function? We wanted to have something like List[List[float; 784]; 10] (the number next to the type is the size of the list). That is nothing more than a matrix; we will store it as 784 rows by 10 columns.

If you’re wondering where these magic numbers come from: 784 is used because our input image has a size of 28x28, and we unroll it into a vector of 784 values. The number 10 comes from the ten classes of digits we’re going to recognize.

Let’s initialize the weights randomly:

weights = init_params((28*28,10))

The same goes for biases:

biases = init_params(10)

Now we can try to calculate the predictions for a single image. Instead of looping over every digit and multiplying vectors, we can do everything in one step using matrix-vector multiplication. And if the tensors live on the GPU, PyTorch will use its power to do this operation as quickly as possible!

train_x[34000] @ weights + biases
tensor([  5.4221,  -1.5908,  17.0564,  -4.3119, -15.0097,  -8.3001,   0.5106,   2.3884,  -6.6653,  13.2021], grad_fn=<AddBackward0>)

The new operator @ above is a PyTorch operator for matrix multiplication. I will not get into details about what matrix multiplication is, but a nice and intuitive explanation of it can be found on http://matrixmultiplication.xyz/.

Let’s write a function that represents our linear model:

def linear_model(x_batch): 
    return x_batch @ weights + biases
predictions = linear_model(train_x[34000])
tensor([  5.4221,  -1.5908,  17.0564,  -4.3119, -15.0097,  -8.3001,   0.5106,   2.3884,  -6.6653,  13.2021], grad_fn=<AddBackward0>)

One interesting property of our matrix-powered function is that this function can process a batch of several images at once. Instead of providing a single vector as an input, we can provide a matrix containing a few images, then do the matrix multiplication, and then we will get the matrix with the predictions back.

With this in mind, we can calculate predictions for the whole dataset without writing a single loop:

predictions = linear_model(train_x)
predictions.shape
torch.Size([60000, 10])

argmax returns the index of the largest value in a tensor. Since we expect the model to return bigger numbers for the classes it is more confident about, the returned index will be the predicted class. Let’s calculate the accuracy:

(predictions.argmax(dim=1) == train_y).float().mean().item()


The model initialized using random parameters made around 10% of predictions correct. That’s pretty much what we would expect for a random model since the probability of being correct with a random guess is 0.1.


So we have solved the first problem: the prediction function is fast and can run on a GPU. But we still have no idea where to get the proper (non-random) parameters.

Now, let’s talk about gradient descent (the “stochastic” part will be explained later).

As you might remember from a calculus course, the derivative of a function is another function that represents the rate of change of the function’s output with respect to its input. If we evaluate the derivative at some point where our function is defined, we can think of the result as the direction in which the function grows the fastest at that point. If we know this, we can figure out in which direction we need to go to move toward a minimum of the function.

In case our function operates on several input variables, instead of the derivative we will calculate a gradient, which can be imagined as an “advanced” version of the derivative capable of working with multi-variable functions. The gradient is similar to the derivative since it will give us the direction in which the function is going to grow the fastest.

Once we have the gradient, we can multiply it by -1 to get the direction in which the function declines the fastest, multiply it by a small value (the step size), and update the parameters. As long as the step is small enough, the function we’re interested in will then produce a smaller value for the updated parameters.

Using this information, we can describe the gradient descent process as follows:

initialize function parameters randomly
while (function parameters are not good enough):
    calculate gradient
    parameters = parameters - gradient * step

Gradient descent can be applied to any differentiable function, and we will use it for our training process.
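Here is a minimal sketch of the loop above on a toy function, f(w) = (w - 3)², whose minimum is at w = 3 (the function, step size and number of iterations are arbitrary choices for illustration). PyTorch computes the gradient for us with backward():

```python
import torch

w = torch.randn(1).requires_grad_()   # random initial parameter
step = 0.1                            # learning rate

for _ in range(100):
    loss = ((w - 3) ** 2).sum()       # the function we want to minimize
    loss.backward()                   # computes d(loss)/dw and stores it in w.grad
    with torch.no_grad():
        w -= w.grad * step            # move against the gradient
    w.grad.zero_()                    # reset, otherwise gradients accumulate

print(w.item())  # very close to 3.0
```

Each step shrinks the distance to 3 by a factor of 0.8 here, so 100 steps are plenty.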

Next, we need to define a so-called loss function, which tells us how good or bad our linear model performs for a given set of parameters.

Once we define it, we will run the gradient descent method with the loss function and find the optimal parameters.

Loss Function

Let’s think about what we want from our loss function. We want it to take images, parameters, and labels as an input and produce a number that shows how “bad” our linear model is.

Taking this into account, our first implementation of the loss function will look like this:

def mnist_loss(predictions, targets, nr_classes=10):
    predictions = predictions.sigmoid()
    mask = F.one_hot(targets, nr_classes)
    errors = ((mask-predictions) ** 2)
    return errors.mean()

Let’s analyze it line by line:

predictions = predictions.sigmoid()

We apply a sigmoid function to each prediction.

def sigmoid(x): return 1/(1+torch.exp(-x))

The sigmoid function maps its input to a range between 0 and 1, and it looks like this:


Our idea is to subtract the predictions for the correct classes from 1, so that when our model is correct and confident in its prediction, the error is close to zero. For all other classes we subtract the prediction from 0, which results in a low error as long as the model does not give high confidence to incorrect labels.

mask = F.one_hot(targets, nr_classes)

Here we generate a mask which represents the correct labels. For example, F.one_hot(tensor([0,2,1]), 3) gives us

tensor([[1, 0, 0],
        [0, 0, 1],
        [0, 1, 0]])


Next, we subtract our predictions from the mask. To deal with negative numbers, we take the square of the error, which keeps the function differentiable and has the nice bonus of punishing big errors more than small ones.

errors = ((mask-predictions) ** 2)

Finally, we take the mean to reduce the error tensor to a single number:

return errors.mean()
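A quick sanity check of this loss, as a sketch on made-up scores for two images and three classes (nr_classes=3 is chosen only to keep the example small): confident, correct scores should give a much lower loss than confident, wrong ones.

```python
import torch
import torch.nn.functional as F

def mnist_loss(predictions, targets, nr_classes=10):
    predictions = predictions.sigmoid()          # squash scores into (0, 1)
    mask = F.one_hot(targets, nr_classes)        # 1 for the correct class, 0 elsewhere
    errors = (mask - predictions) ** 2           # squared distance from the ideal output
    return errors.mean()

targets = torch.tensor([0, 2])
good = torch.tensor([[ 5., -5., -5.],            # high score for the correct class
                     [-5., -5.,  5.]])
bad = torch.tensor([[-5.,  5., -5.],             # high score for a wrong class
                    [ 5., -5., -5.]])

print(mnist_loss(good, targets, nr_classes=3))   # close to 0
print(mnist_loss(bad, targets, nr_classes=3))    # much larger
```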


Having defined a loss function, we want to calculate its gradient. However, we run into a problem: our loss function operates on predictions and labels, not on the model parameters.

To include the weights and biases in the gradient calculation, we will take the derivative of a new function, which is a composition of the model and the loss function. The new function looks like this:

def composed(x_batch, y_batch):
    preds = linear_model(x_batch)      # predictions depend on the weights and biases
    loss = mnist_loss(preds, y_batch)
    return loss

Now we need to calculate a gradient. Doing that by hand, by calculating partial derivatives with the chain rule, would be tedious, because this function operates on thousands of parameters.

Fortunately for us, PyTorch can do this quickly. All we need to do is call .requires_grad_() on the tensors whose gradients we’re interested in before computing the function (our init_params helper already does this), and then call .backward() on the result to tell PyTorch that we want the gradients.

Here is how we do it:

def calc_grad(x_batch, y_batch, model, loss_fn):
    preds = model(x_batch)
    loss = loss_fn(preds, y_batch)
    # compute the gradients of the loss and store them in .grad of each parameter
    loss.backward()

Let’s talk about how we will feed the data to our gradient descent process. One option is to feed it one image at a time, but that would take an unreasonable amount of time due to the data-moving overhead. Training on the whole dataset at once is not desirable either, since the dataset might simply be too big to fit on the GPU. What is commonly done instead is the following: the dataset is divided into batches, and we calculate the gradients using one whole batch at a time. That’s where the word stochastic comes from: the batches contain shuffled items, so gradient descent runs on data randomly sampled from the dataset.

When we iterate through the dataset, we generally prefer to have diverse examples in our batches because this leads to better generalization. An easy way to achieve this is to shuffle the dataset. Fortunately for us, PyTorch provides a class called DataLoader, which does the shuffling and batch separation for us:

dl = DataLoader(dataset, batch_size=256,shuffle=True)
valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=False)

Training loop

We can now write a function that runs SGD over the whole dataset (one iteration through all images is called an epoch):

def train_epoch(model, loss_fn, lr, params):
    for xb,yb in dl:
        calc_grad(xb, yb, model, loss_fn)
        for p in params:
            p.data -= p.grad * lr
            # reset the gradient; otherwise PyTorch adds the next batch's gradient to it
            p.grad.zero_()

For every batch, we compute the gradients, multiply them by a small number called the learning rate, and subtract the result from the current parameters. The learning rate should be low enough to keep the process stable, and at the same time large enough that training does not take forever.
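Here is a self-contained sketch of the same training loop on toy data, runnable without MNIST. The data (x, y, true_w) is hypothetical, PyTorch’s built-in F.cross_entropy stands in for this post’s hand-written losses, and batching uses torch.utils.data.DataLoader. Note that it resets each gradient with p.grad.zero_() after the update: PyTorch accumulates gradients across backward() calls, so without the reset every step would use the sum of all previous gradients.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical toy data: 512 random 20-dim inputs, labels from a random linear rule
x = torch.randn(512, 20)
true_w = torch.randn(20, 3)
y = (x @ true_w).argmax(dim=1)
dl = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)

weights = torch.randn(20, 3).requires_grad_()
bias = torch.zeros(3).requires_grad_()
params = [weights, bias]

def model(xb): return xb @ weights + bias

def train_epoch(model, loss_fn, lr, params):
    for xb, yb in dl:
        loss = loss_fn(model(xb), yb)
        loss.backward()               # fill p.grad for every parameter
        for p in params:
            p.data -= p.grad * lr     # SGD step
            p.grad.zero_()            # clear gradients before the next batch

loss_before = F.cross_entropy(model(x), y).item()
for _ in range(5):
    train_epoch(model, F.cross_entropy, 0.1, params)
loss_after = F.cross_entropy(model(x), y).item()
print(loss_before, loss_after)  # the loss should go down
```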

It is important to know how accurate our model is. We will start with the accuracy of a single batch:

def batch_accuracy(preds, yb):
    # fraction of images whose highest-scoring class matches the label
    correct = (preds.argmax(dim=1) == yb).float().mean()
    return correct

Once we know how to measure the accuracy of a batch, we can measure the accuracy of the model on the whole validation set:

def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(torch.stack(accs).mean().item(), 4)

Let’s see what the accuracy of our randomly initialized model is:

validate_epoch(linear_model)


Something around 10% is what we expect.

Finally, we can train a single epoch:

accuracy = []
lr = 1.
# initializing random weights
weights = init_params((28*28,10))
# initializing random biases
biases = init_params(10)

params = weights,biases
# training an epoch
train_epoch(linear_model, mnist_loss, lr, params)
# validating results
accuracy.append(validate_epoch(linear_model))

Let’s train for 40 epochs:

for i in range(40):
    train_epoch(linear_model, mnist_loss, lr, params)
    accuracy.append(validate_epoch(linear_model))



Note that the epoch index starts at 0.

The training starts fast, but it slows down as the model’s accuracy increases. The problem is that our model can be very certain about several different classes for the same image, which is not the task we’re trying to solve. The reason for this behavior lies in our loss function. Remember, the loss function first applies a sigmoid to keep each output between 0 and 1. The sigmoid scales every output independently, but what we want is for the model to commit to a single label in the end.

Let’s replace the sigmoid function with the so-called softmax function. It is closely related to the sigmoid (for two classes, it equals a sigmoid of the difference between the two scores), but it makes the output probabilities of all classes sum to 1. Thus, only the relative differences between the predictions matter, and a high degree of confidence in one class automatically decreases the confidence in the other classes.
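A small sketch contrasting the two functions on made-up scores for three classes: sigmoid squashes each score independently, so several classes can look likely at once, while softmax makes the classes compete.

```python
import torch

scores = torch.tensor([[2.0, 1.0, 0.1]])   # made-up scores for 3 classes

sig = torch.sigmoid(scores)                # each score squashed independently
soft = torch.softmax(scores, dim=1)        # scores compete: they sum to 1

print(sig.sum().item())   # > 1: sigmoid can be confident about several classes
print(soft.sum().item())  # 1.0 (up to floating-point rounding)

# Raising one score lowers the softmax probability of all the others
soft_up = torch.softmax(torch.tensor([[4.0, 1.0, 0.1]]), dim=1)
print(soft_up[0, 1] < soft[0, 1])  # tensor(True)

# With two classes, softmax equals a sigmoid of the score difference
two = torch.tensor([[1.5, -0.5]])
print(torch.softmax(two, dim=1)[0, 0], torch.sigmoid(two[0, 0] - two[0, 1]))
```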

def mnist_loss_softmax(predictions, targets, nr_classes=10):
    # softmax across the classes of each image, so each row sums to 1
    predictions = torch.softmax(predictions, dim=1)
    mask = F.one_hot(targets, nr_classes)
    errors = ((mask-predictions) ** 2)
    return errors.mean()

Let’s train the model again:

accuracy = []

lr = 1.
# initializing random weights
weights = init_params((28*28,10))
# initializing random biases
biases = init_params(10)

params = weights,biases
# training an epoch
train_epoch(linear_model, mnist_loss_softmax, lr, params)
# validating results
accuracy.append(validate_epoch(linear_model))
for i in range(40):
    train_epoch(linear_model, mnist_loss_softmax, lr, params)
    accuracy.append(validate_epoch(linear_model))


That helped, but learning still slows down as the accuracy gets higher, and we have not yet surpassed our baseline.

Let’s blame the loss function again! Here is the problem we face: for our loss function, the difference between predicting the correct class with a probability of 0.9 and with a probability of 0.99 is small, only 0.09. But if we think about it, the second prediction is ten times closer to perfect (an error of 0.01 instead of 0.1)! To address this problem, we’re going to apply the negative logarithm function to the predicted probabilities. It rescales them so that SGD can find the right direction more easily.

The negative logarithm function looks like this:


Such a loss function is then called a cross-entropy loss:

def cross_entropy_loss(preds, y):
    # apply softmax
    preds = torch.softmax(preds, dim=1)

    # get confidences for the correct class
    idx = len(preds)
    confidences = preds[range(idx), y]

    # calculate negative log likelihood and return its mean
    log_ll = -torch.log(confidences)
    return log_ll.mean()
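As a sanity check, this hand-written loss should agree with PyTorch’s built-in F.cross_entropy, which also takes raw scores but computes a numerically more stable log-softmax internally. A sketch with random scores and labels, plus the 0.9 vs 0.99 comparison from above:

```python
import torch
import torch.nn.functional as F

def cross_entropy_loss(preds, y):
    preds = torch.softmax(preds, dim=1)            # scores -> probabilities
    confidences = preds[range(len(preds)), y]      # probability of the correct class
    return -torch.log(confidences).mean()          # mean negative log likelihood

torch.manual_seed(0)
scores = torch.randn(8, 10)          # made-up model outputs for 8 images
labels = torch.randint(0, 10, (8,))  # made-up correct digits

print(cross_entropy_loss(scores, labels), F.cross_entropy(scores, labels))  # same value

# Why the log helps: confidences of 0.9 and 0.99 differ by only 0.09,
# but their negative logs differ by about a factor of 10.
print(-torch.log(torch.tensor([0.9, 0.99])))  # about 0.105 and 0.010
```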

Let’s train the model again using the new shiny loss function:

accuracy = []
lr = 1.
# initializing random weights
weights = init_params((28*28,10))
# initializing random biases
biases = init_params(10)

params = weights,biases
# training an epoch
train_epoch(linear_model, cross_entropy_loss, lr, params)
# validating results
accuracy.append(validate_epoch(linear_model))

Wow, after just a single epoch of training we got 85% accuracy, outperforming our baseline. That’s a huge improvement! Let’s train for a few more epochs:

for i in range(40):
    train_epoch(linear_model, cross_entropy_loss, lr, params)
    accuracy.append(validate_epoch(linear_model))


91.5% accuracy with a simple linear model! Isn’t that impressive?

But so far our model can hardly be called a neural network, since it consists of a single linear layer. To make it a “real” neural net, we should add some non-linearity. For example, we can add a sigmoid layer between two linear layers and then use function composition to apply them sequentially:

w1 = init_params((28*28,64))
b1 = init_params(64)
w2 = init_params((64,10))
b2 = init_params(10)

def simple_neural_net(xb): 
    res = xb@w1 + b1 # 1st linear layer
    res = torch.sigmoid(res) # 2nd non-linear layer
    res = res@w2 + b2 # 3rd linear layer
    return res

params = [w1,b1,w2,b2]
accuracy = []
lr = 1.

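# training an epoch
train_epoch(simple_neural_net, cross_entropy_loss, lr, params)
# validating results
accuracy.append(validate_epoch(simple_neural_net))
for i in range(40):
    train_epoch(simple_neural_net, cross_entropy_loss, lr, params)
    accuracy.append(validate_epoch(simple_neural_net))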


Around 95% accuracy! That’s definitely better than the plain linear model.

As you can see, our new non-linear model is a drop-in replacement for the old one (as long as it has the same number of inputs and outputs), and the rest of the process stays exactly the same.

Finally, as a fun example let’s see how this task can be done using a neural network designed for image recognition:

from fastai.metrics import accuracy  # restore fastai's metric; the name was reused for a list above
dls = ImageDataLoaders.from_folder(path, train="training", valid="testing")
learn = cnn_learner(dls, resnet18, pretrained=False,
                    loss_func=F.cross_entropy, metrics=accuracy)
learn.fit_one_cycle(4, 0.1)

Wow, 99.4% only after four epochs! Those modern neural networks are powerful.

It is worth mentioning that SGD is not ideal. For example, there are no guarantees that it actually finds good parameters in a finite number of steps, let alone optimal ones.

Also, the loss of a neural network is a non-convex function of its parameters, so there may be many local minima where SGD can get stuck. This is one reason why, when a problem can be solved well with a simpler method, the simpler method is often the better choice.

That said, ML researchers have made significant progress in the area of SGD, and neural networks trained with modern SGD variants and regularization have achieved state-of-the-art results on many problems.


We started with a simple pixel-similarity baseline, which achieves about 82% accuracy without any learning.

Then we moved on to a simple linear model, which we optimized using gradient descent. As we have seen, it is nothing more complicated than a combination of matrix multiplication and derivative calculation.

Finally, we replaced our linear model with a neural network and saw some accuracy improvements.

There is something cool about combining a neural net (which can approximate any function given the right parameters) with gradient descent (which can search for good parameters of any differentiable function), isn’t there?


A big thank you goes to Felix Patzelt for reviewing this post.