
Early Stopping

by Elvira Siegel
(Published: Fri Aug 07, 2020)

In this article we will introduce you to the concept of Early Stopping and its implementation, including code samples.


The main goal of training a neural network is to acquire a model which is able to generalize optimally on new unseen data.

Unfortunately, most neural network architectures suffer from overfitting. Overfitting is the effect where the model fits the training data too well. As a result, the model is not able to “make sense” of new, unseen data.


[Figure: overfitting]

We can also see overfitting as an error that keeps decreasing on the training data while increasing on the test data. Ideally, we want a small error at test time.


[Figure: training error vs. test error under overfitting]

A typical overfitted model will look something like the rightmost plot in the picture below:


[Figure: overfitted model (rightmost plot)]

Why is overfitting bad? The central idea of a machine learning algorithm is to capture the dominant trend in the data and to recognize this trend again on new data. If we get a graph like the rightmost one in the picture above, it indicates that the model did not really learn anything but simply memorized the training data without “understanding” it; hence the test error will be large (in other words, the network's accuracy will be low).


Data Split

First of all, we start off by splitting the whole data into three main subsets: training, validation and test set.

The biggest subset is the training data set. We use it to actually train the model, which means we tweak weights and biases by means of backpropagation. The model learns to “see the world” from this data.

Next comes the validation data set. This subset is used to give an unbiased evaluation of a model after fitting it on the training data set. We also tune the model's hyperparameters on the validation set.

With the validation set we can perform evaluation frequently. This works well because the model does not learn from the validation data set; it only “sees” it.

Sometimes the validation set is also referred to as the development set.

Finally, the test data set. This subset is used to give an unbiased evaluation of the final model fitted (=trained) on the training data set.

We use the test data set only once, when our model is completely trained, i.e. after we have already made use of the training and validation sets.



[Figure: train/validation/test split]

The splitting proportions may vary. As a rule of thumb we can use one of the following split ratios (a minimal splitting sketch follows the list):

  • 70% train, 15% val, 15% test
  • 80% train, 10% val, 10% test
  • 60% train, 20% val, 20% test
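
For illustration, here is a minimal sketch of an 80/10/10 split using PyTorch's random_split; the full_dataset object is a stand-in for your own dataset:

import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset: 1000 samples with 20 features and binary labels.
full_dataset = TensorDataset(torch.randn(1000, 20), torch.randint(0, 2, (1000,)))

n_total = len(full_dataset)
n_train = int(0.8 * n_total)        # 80% training
n_val = int(0.1 * n_total)          # 10% validation
n_test = n_total - n_train - n_val  # remaining 10% test

train_set, val_set, test_set = random_split(
    full_dataset, [n_train, n_val, n_test],
    generator=torch.Generator().manual_seed(42)  # reproducible split
)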

Early Stopping and Validation Error

Commonly, the generalization error is approximated by the validation error (also called validation loss): the average error computed on the validation set.

Validation can be used to detect when overfitting starts during supervised training.

Overfitting shows itself when the training error is small but the validation error (or, if we use only a train/test split, the test error) is large. We can utilize early stopping to improve the error rate.
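
As a rough sketch, the validation error can be computed as the average loss over the validation set after each training epoch; here model, val_loader and criterion are placeholders for your own network, data loader and loss function:

import torch

def validate(model, val_loader, criterion):
    '''Average loss over the validation set (no weight updates).'''
    model.eval()
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():  # the model only "sees" the data, it does not learn from it
        for inputs, targets in val_loader:
            total_loss += criterion(model(inputs), targets).item()
            n_batches += 1
    return total_loss / max(n_batches, 1)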



Early Stopping

Early stopping is one of the regularization techniques that help us avoid overfitting. With early stopping we monitor and record the validation error. If the error stops decreasing for a number of epochs in succession, the training stops.



[Figure: early stopping]


Let’s see an implementation example with PyTorch.

Look at the class EarlyStopping below:


import numpy as np
import torch


class EarlyStopping:

    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decreases.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss   

Link to the complete project by Bjarten


The class EarlyStopping tracks the validation error during the model training. A checkpoint of the model is saved each time the validation error decreases:

  torch.save(model.state_dict(), self.path)   

Patience

The patience parameter is important for early stopping. Patience is the number of epochs to wait for an improvement (a decrease) in the validation error. If after n epochs there is still no improvement, the training is terminated.
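
To see the patience counter in action, here is a small, self-contained run of the EarlyStopping class from above on a made-up sequence of validation losses (the numbers and the dummy model are purely illustrative); with patience=2, training stops after two consecutive epochs without improvement:

import torch.nn as nn

dummy_model = nn.Linear(2, 1)  # tiny model so save_checkpoint has something to store
stopper = EarlyStopping(patience=2)

# Made-up losses: two improvements, then stagnation.
for epoch, val_loss in enumerate([0.90, 0.80, 0.79, 0.81, 0.82]):
    stopper(val_loss, dummy_model)
    print(f'epoch {epoch}: val_loss={val_loss}, counter={stopper.counter}, stop={stopper.early_stop}')
    if stopper.early_stop:
        break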

Let's look at an example from MNIST_Early_Stopping_example.ipynb of how to use the EarlyStopping class.


We import the class from the pytorchtools.py module:

  from pytorchtools import EarlyStopping   

To initialize an early_stopping object, we do:

  early_stopping = EarlyStopping(patience=patience, verbose=True)   

The early_stopping object checks whether the validation error has decreased. If it has, it creates a checkpoint of the current model:

  early_stopping(valid_loss, model)

  if early_stopping.early_stop:
      print("Early stopping")
      break

Once the training loop has finished (or was stopped early), we load the last checkpoint, which holds the best model:

 model.load_state_dict(torch.load('checkpoint.pt'))  
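
Putting the pieces together, a simplified training loop could look roughly like the following sketch. It reuses the EarlyStopping class from above, while model, optimizer, criterion, train_loader and valid_loader are assumed to be defined elsewhere:

def train_model(model, optimizer, criterion, train_loader, valid_loader,
                n_epochs=100, patience=7):
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    for epoch in range(n_epochs):
        # training phase: tweak weights and biases via backpropagation
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()

        # validation phase: average validation loss for this epoch
        model.eval()
        valid_losses = []
        with torch.no_grad():
            for inputs, targets in valid_loader:
                valid_losses.append(criterion(model(inputs), targets).item())
        valid_loss = sum(valid_losses) / len(valid_losses)

        # early stopping check: saves a checkpoint on improvement,
        # otherwise increases the patience counter
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # load the last checkpoint with the best model
    model.load_state_dict(torch.load('checkpoint.pt'))
    return model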

Another way to think of early stopping is that it effectively restricts the network's capacity, similar to using a smaller network, while still giving better generalization.


Further Research

For those of you who are interested in the topic, a new GitHub project about Argument Mining Across Heterogeneous Datasets will be linked here in September. It features argument mining implemented with BERT using the Hugging Face Transformers library and PyTorch, where you can see an example of applying early stopping in a more complex setting.


Conclusion

We have learned that stopping the training of a neural network early, before it overfits the training data set, reduces overfitting and improves the network's generalization capacity.


Further recommended readings:

GitHub: Early Stopping with PyTorch

Early Stopping - but when?
