
PyTorch Workflow Fundamentals

This chapter will cover topics such as:

  • Data (preparing and loading)
  • Building a simple model
  • Fitting the model to data (i.e. training the model)
  • Making predictions and evaluating a model (inference)
  • Saving and loading a model

PyTorch Workflow

Data

A fundamental practice around data in Machine Learning is to split the data into three different groups:

  1. Training data
  2. Validation data
  3. Test data

The training data is used to train the model to find the sought patterns and parameters. The validation data is used to validate the model and to make adjustments to it if needed. Finally, the test data is used to see if the model is ready for use.

Typically, between 60-80% of your total data should be used for training, 10-20% for validation, and 10-20% for testing. However, a validation set is not always used, since the test set can also serve as a validation set.
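As a minimal sketch, assuming a hypothetical dataset of 100 samples and an 80/10/10 split, slicing the data into the three groups can look like this:

```python
import torch

# Hypothetical dataset of 100 samples (the data and split ratios are assumptions)
X = torch.arange(0, 100, dtype=torch.float32)

train_end = int(0.8 * len(X))  # first 80% for training
val_end = int(0.9 * len(X))    # next 10% for validation

X_train = X[:train_end]
X_val = X[train_end:val_end]
X_test = X[val_end:]           # final 10% for testing
```

In practice the data is usually shuffled before splitting so that each group is representative of the whole dataset.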

Building Models

What type of model you build depends entirely on the purpose of the model. However, in PyTorch you always build a class that inherits from torch.nn.Module. This base class contains all the necessary building blocks needed to build any Machine Learning model. See nn.Module for the documentation.

Below is an example of a very simple model that inherits from nn.Module and implements linear regression.

import torch
from torch import nn

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable parameters, initialized randomly
        self.weights = nn.Parameter(torch.randn(1))
        self.bias = nn.Parameter(torch.randn(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Linear regression: y = weights * x + bias
        return self.weights * x + self.bias

The super().__init__() call in the snippet above runs the initialization of nn.Module, which sets up everything the module needs. One important method in the snippet is forward(), which overrides the one from nn.Module. This is where we define what the model is supposed to do; here it computes the linear regression formula weights * x + bias.

To summarize, these are a few essential modules used when building a simple model like the one above:

  • torch.nn contains all of the building blocks for a neural network
  • torch.nn.Parameter stores the parameters that the model should try to learn; often a PyTorch layer will set these for us
  • torch.nn.Module is the base class for all neural networks; if you inherit from it, you should override the forward() method
  • torch.optim contains various optimization algorithms

A good place to see what various PyTorch modules do is the PyTorch cheat sheet.
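To illustrate the point about torch.nn.Parameter above, a built-in layer such as nn.Linear creates its learnable parameters for us; the layer sizes here are arbitrary:

```python
import torch
from torch import nn

# nn.Linear registers its own weight and bias as nn.Parameter objects,
# so we do not have to define them by hand
layer = nn.Linear(in_features=3, out_features=1)

params = dict(layer.named_parameters())
print(params["weight"].shape)  # torch.Size([1, 3])
print(params["bias"].shape)    # torch.Size([1])
```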

Model Training

In order to make the model predict data, it has to be trained. This is done with the training data set. Below are the essential building blocks needed for training the model.

Loss functions and Optimizers

  • Loss Function: a function that measures how wrong the model's predictions are
  • Optimizer: uses the model's loss and adjusts the model's parameters to make the predictions better, by minimizing the loss function

To see which built-in loss functions and optimizers exist within PyTorch, see Loss Functions and Optimizers. In all models, these have to be picked by the developer, and which one is right to use depends on the model being built.

Another important value the developer has to choose is the hyperparameter lr, which stands for learning rate. It tells the optimizer how much the parameter values should change in each step in order to decrease the loss.
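A minimal sketch of picking a loss function and an optimizer; the choice of MSE loss, SGD, the placeholder model, and the lr value are all assumptions for illustration:

```python
import torch
from torch import nn

model = nn.Linear(1, 1)  # placeholder model whose parameters we want to train

loss_fn = nn.MSELoss()   # mean squared error: measures how wrong the predictions are
optimizer = torch.optim.SGD(params=model.parameters(),  # parameters to adjust
                            lr=0.01)                    # learning rate
```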

Training loop

The training loop is where the model is being trained in order to update the parameter values. There are a few essential steps in this loop that need to happen, and these are listed below.

Loop through the data and:

  1. Forward pass: move the input data through the network to produce an output with the current parameters
  2. Calculate the loss with the loss function
  3. Optimizer zero grad: reset the gradients of the optimized tensors
  4. Backpropagation: move through the network backwards to calculate the loss gradients
  5. Optimizer step: refine the parameters (gradient descent)

This is the most common order in which to do these steps.
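The five steps above can be sketched in code. This assumes made-up data following y = 2x + 1, a plain nn.Linear model, MSE loss, and SGD; all of these are arbitrary choices for illustration:

```python
import torch
from torch import nn

torch.manual_seed(42)

# Made-up training data: y = 2x + 1
X = torch.arange(0, 1, 0.02).unsqueeze(dim=1)
y = 2 * X + 1

model = nn.Linear(in_features=1, out_features=1)
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(200):
    y_pred = model(X)           # 1. forward pass
    loss = loss_fn(y_pred, y)   # 2. calculate the loss
    optimizer.zero_grad()       # 3. reset the gradients
    loss.backward()             # 4. backpropagation
    optimizer.step()            # 5. optimizer step (gradient descent)
```

After enough epochs the loss shrinks as the layer's weight and bias approach 2 and 1.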