Cross validation

Rafael Irizarry

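A common goal of machine learning is to find an algorithm that produces predictors \hat{Y} for an outcome Y that minimizes the MSE:

\mbox{MSE} = \mbox{E}\left\{ \frac{1}{N}\sum_{i=1}^N \left(\hat{Y}_i - Y_i\right)^2 \right\}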

When all we have at our disposal is one dataset, we can estimate the MSE with the observed MSE like this:

\hat{\mbox{MSE}} = \frac{1}{N}\sum_{i=1}^N \left(\hat{y}_i - y_i\right)^2

These two are often referred to as the true error and apparent error, respectively.

There are two important characteristics of the apparent error we should always keep in mind:

  1. Because our data is random, the apparent error is a random variable. For example, the dataset we have may be a random sample from a larger population. An algorithm may have a lower apparent error than another algorithm due to luck.
  2. If we train an algorithm on the same dataset that we use to compute the apparent error, we might be overtraining. In general, when we do this, the apparent error will be an underestimate of the true error. We will see an extreme example of this with k-nearest neighbors.
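To make point 2 concrete, here is a minimal Python sketch that is not part of the original text. It assumes a hypothetical data-generating model, y = sin(2πx) plus Gaussian noise with standard deviation 0.3, and a hand-written nearest-neighbor regressor; `simulate`, `knn_predict`, and `mse` are illustrative helpers, not library functions. Because each training point is its own nearest neighbor, 1-nearest-neighbor has an apparent error of exactly zero on its training data, while its error on a fresh sample is clearly positive.

```python
import math
import random


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


def mse(y_hat, y):
    """Mean squared error between predictions and observed outcomes."""
    return sum((p - o) ** 2 for p, o in zip(y_hat, y)) / len(y)


rng = random.Random(1)
train_x, train_y = simulate(200, rng)
new_x, new_y = simulate(200, rng)  # a fresh sample, never used for training

# Apparent error: evaluate 1-NN on the same data used to train it.
apparent = mse([knn_predict(train_x, train_y, x, 1) for x in train_x], train_y)
# Error on new data: a single-sample stand-in for the true error.
on_new = mse([knn_predict(train_x, train_y, x, 1) for x in new_x], new_y)
print(f"apparent error: {apparent:.3f}, error on new data: {on_new:.3f}")
```

The apparent error prints as 0.000 because every training point predicts its own outcome; the error on new data does not share this optimism.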

Cross validation is a technique that permits us to alleviate both these problems. To understand cross validation, it helps to think of the true error, a theoretical quantity, as the average of many apparent errors obtained by applying the algorithm to B new random samples of the data, none of them used to train the algorithm. As shown in a previous chapter, we think of the true error as:

\frac{1}{B} \sum_{b=1}^B \frac{1}{N}\sum_{i=1}^N \left(\hat{y}_i^b - y_i^b\right)^2

with B a large number that can be thought of as practically infinite. As already mentioned, this is a theoretical quantity because we only have available one set of outcomes: y_1, \dots, y_N. Cross validation is based on the idea of imitating the theoretical setup above as best we can with the data we have. To do this, we have to generate a series of different random samples. There are several approaches we can use, but the general idea for all of them is to randomly generate smaller datasets that are not used for training and are instead used to estimate the true error.
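When data can be simulated, which is never the case in practice, the double sum above can be approximated directly. The sketch below is an illustration under stated assumptions: a hypothetical model y = sin(2πx) plus Gaussian noise, a hand-written k-nearest-neighbor fit with k = 10 (the helpers `simulate` and `knn_predict` are hypothetical), and B = 200 new samples of size N = 100. It trains once, computes the apparent error on each new sample, and averages them. The spread of the individual errors also illustrates that each apparent error is a random variable.

```python
import math
import random
import statistics


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


rng = random.Random(2)
train_x, train_y = simulate(100, rng)  # the one dataset used for training
k, B, N = 10, 200, 100

errors = []  # apparent error on each new sample b = 1, ..., B
for b in range(B):
    xs, ys = simulate(N, rng)  # new sample b, never used for training
    preds = [knn_predict(train_x, train_y, x, k) for x in xs]
    errors.append(sum((p - y) ** 2 for p, y in zip(preds, ys)) / N)

true_error = statistics.mean(errors)  # approximates the double sum above
print(f"estimated true error: {true_error:.3f}")
print(f"individual errors range from {min(errors):.3f} to {max(errors):.3f}")
```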

K-fold cross validation

The first approach we describe is K-fold cross validation. Generally speaking, a machine learning challenge starts with a dataset (blue in the image below). We need to build an algorithm using this dataset that will eventually be used on completely independent datasets (yellow).

[Figure: a long blue bar labeled "Data" and a shorter yellow bar labeled "Independent Set".]

But we don’t get to see these independent datasets.

[Figure: a single long blue bar labeled "Data".]

So to imitate this situation, we carve out a piece of our dataset and pretend it is an independent dataset: we divide the dataset into a training set (blue) and a test set (red). We will train our algorithm exclusively on the training set and use the test set only for evaluation purposes.

We usually try to select a small piece of the dataset so that we have as much data as possible to train. However, we also want the test set to be large so that we obtain a stable estimate of the loss without fitting an impractical number of models. Typical choices are to use 10%-20% of the data for testing.
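A minimal sketch of carving out a test set, using only Python's standard library. The helper `split_indices` is a hypothetical name chosen for illustration, and the 20% fraction is simply the upper end of the typical range above. It shuffles row indices and returns two disjoint lists, so the test rows can be set aside before any training or tuning happens.

```python
import random


def split_indices(n, test_fraction, rng):
    """Shuffle row indices 0..n-1 and return (train, test) as disjoint lists."""
    indices = list(range(n))
    rng.shuffle(indices)
    n_test = round(n * test_fraction)
    return indices[n_test:], indices[:n_test]


rng = random.Random(3)
train_idx, test_idx = split_indices(1000, 0.2, rng)
print(len(train_idx), len(test_idx))  # 800 200
```

Working with indices rather than copies of the rows keeps the split explicit and makes it easy to check that no test row leaks into training.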

[Figure: a long blue bar with a smaller red section; the blue section is labeled "Train" and the red section is labeled "Test".]

Let’s reiterate that it is indispensable that we not use the test set at all: not for filtering out rows, not for selecting features, nothing!

Now this presents a new problem because for most machine learning algorithms we need to select parameters, for example the number of neighbors k in k-nearest neighbors. Here, we will refer to the set of parameters as \lambda. We need to optimize algorithm parameters without using our test set and we know that if we optimize and evaluate on the same dataset, we will overtrain. This is where cross validation is most useful.

For each set of algorithm parameters being considered, we want an estimate of the MSE and then we will choose the parameters with the smallest MSE. Cross validation provides this estimate.

First, before we start the cross validation procedure, it is important to fix all the algorithm parameters. Although we will train the algorithm on several training sets, the parameters \lambda will be the same across all of them. We will use \hat{y}_i(\lambda) to denote the predictors obtained when we use parameters \lambda.

So, if we are going to imitate this definition:

\mbox{MSE}(\lambda) = \frac{1}{B} \sum_{b=1}^B \frac{1}{N}\sum_{i=1}^N \left(\hat{y}_i^b(\lambda) - y_i^b\right)^2

we want to consider datasets that can be thought of as an independent random sample and we want to do this several times. With K-fold cross validation, we do it K times. In the cartoons, we are showing an example that uses K=5.

We will eventually end up with K samples, but let’s start by describing how to construct the first: we simply pick M=N/K observations at random (rounding if N/K is not an integer) and think of these as a random sample y^b_1, \dots, y^b_M with b=1. We call this the validation set:

[Figure: a bar split into "Train" and "Test" becomes a bar in three sections labeled "Train 1", "Validate 1", and "Test".]

Now we can fit the model on the training set, then compute the apparent error on the validation set:

\hat{\mbox{MSE}}_b(\lambda) = \frac{1}{M}\sum_{i=1}^M \left(\hat{y}_i^b(\lambda) - y_i^b\right)^2
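Constructing the first validation set and computing \hat{\mbox{MSE}}_1(\lambda) can be sketched as follows. This assumes a hypothetical model y = sin(2πx) plus Gaussian noise and a hand-written k-nearest-neighbor fit (`simulate` and `knn_predict` are illustrative helpers), with \lambda standing for the number of neighbors and set to an arbitrary 15. With N = 200 and K = 5, M = 40 observations are held out and the remaining 160 are used to fit.

```python
import math
import random


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


rng = random.Random(4)
x, y = simulate(200, rng)  # stands in for the training set
N, K = len(x), 5
M = round(N / K)  # size of the validation set

shuffled = list(range(N))
rng.shuffle(shuffled)
validate, fit = shuffled[:M], shuffled[M:]  # validation set b = 1 and the rest

lam = 15  # lambda: the number of neighbors, fixed before cross validation
fit_x = [x[i] for i in fit]
fit_y = [y[i] for i in fit]
mse_1 = sum((knn_predict(fit_x, fit_y, x[i], lam) - y[i]) ** 2 for i in validate) / M
print(f"M = {M}, MSE_1(lambda = {lam}) = {mse_1:.3f}")
```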

Note that this is just one sample and will therefore return a noisy estimate of the true error. This is why we take K samples, not just one. In K-fold cross validation, we randomly split the observations into K non-overlapping sets:

[Figure: cross validation used to optimize model parameters. The data are split into training and validation sets, the model is trained on the training set and evaluated on the validation set, and the process is repeated.]

Now we repeat the calculation above for each of these sets b=1,\dots,K and obtain \hat{\mbox{MSE}}_1(\lambda),\dots,\hat{\mbox{MSE}}_K(\lambda). Then, for our final estimate, we compute the average:

\hat{\mbox{MSE}}(\lambda) = \frac{1}{K} \sum_{b=1}^K \hat{\mbox{MSE}}_b(\lambda)

and obtain an estimate of our loss. A final step would be to select the \lambda that minimizes the MSE.
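Putting the pieces together, a minimal K-fold sketch in plain Python might look like the following. The data-generating model (y = sin(2πx) plus Gaussian noise), the hand-written k-nearest-neighbor fit, and the grid of candidate \lambda values are hypothetical choices for illustration, as are the helper names. The same K = 5 folds are reused for every \lambda, so differences in \hat{\mbox{MSE}}(\lambda) reflect the parameter rather than a different random split.

```python
import math
import random


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


def k_fold_indices(n, K, rng):
    """Shuffle 0..n-1 and deal the indices into K non-overlapping folds."""
    idx = list(range(n))
    rng.shuffle(idx)
    return [idx[b::K] for b in range(K)]


def cv_mse(x, y, lam, folds):
    """Average of MSE_b(lambda) over the folds b = 1, ..., K."""
    per_fold = []
    for validate in folds:
        held_out = set(validate)
        fit = [i for i in range(len(x)) if i not in held_out]
        fit_x = [x[i] for i in fit]
        fit_y = [y[i] for i in fit]
        sq = [(knn_predict(fit_x, fit_y, x[i], lam) - y[i]) ** 2 for i in validate]
        per_fold.append(sum(sq) / len(validate))
    return sum(per_fold) / len(per_fold)


rng = random.Random(5)
x, y = simulate(200, rng)
folds = k_fold_indices(len(x), 5, rng)
lambdas = [1, 5, 15, 30, 60]  # candidate numbers of neighbors
cv = {lam: cv_mse(x, y, lam, folds) for lam in lambdas}
best = min(cv, key=cv.get)  # the lambda with the smallest estimated MSE
for lam in lambdas:
    print(f"lambda = {lam:2d}: estimated MSE = {cv[lam]:.3f}")
print("selected lambda:", best)
```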

We have described how to use cross validation to optimize parameters. However, we now have to take into account the fact that the optimization occurred on the training data and therefore we need an estimate of our final algorithm's error based on data that was not used to optimize the choice. Here is where we use the test set we separated early on:

[Figure: Estimating a final algorithm.]

We can do cross validation again:

[Figure: Cross validation can be used to estimate loss.]

and obtain a final estimate of our expected loss. Note, however, that this multiplies our entire compute time by K. As you will soon learn, these tasks take time because they involve many complex computations, so we are always looking for ways to reduce this time. For the final evaluation, we often just use the one test set.

Once we are satisfied with this model and want to make it available to others, we could refit the model on the entire dataset, without changing the optimized parameters.
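The workflow described above (set aside a test set, choose \lambda by cross validation on the training set only, evaluate once on the test set, then refit on all the data with \lambda fixed) can be sketched as below. The simulated model (y = sin(2πx) plus Gaussian noise), the hand-written k-nearest-neighbor fit, the helper names, and the candidate grid are all hypothetical; the point is the order of operations and that the test rows are not touched until the final evaluation.

```python
import math
import random


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


def k_fold_indices(n, K, rng):
    """Shuffle 0..n-1 and deal the indices into K non-overlapping folds."""
    idx = list(range(n))
    rng.shuffle(idx)
    return [idx[b::K] for b in range(K)]


def cv_mse(x, y, lam, folds):
    """Average of the per-fold MSE for the parameter lam."""
    per_fold = []
    for validate in folds:
        held_out = set(validate)
        fit = [i for i in range(len(x)) if i not in held_out]
        fit_x = [x[i] for i in fit]
        fit_y = [y[i] for i in fit]
        sq = [(knn_predict(fit_x, fit_y, x[i], lam) - y[i]) ** 2 for i in validate]
        per_fold.append(sum(sq) / len(validate))
    return sum(per_fold) / len(per_fold)


rng = random.Random(6)
x, y = simulate(250, rng)

# 1. Set aside 20% of the rows as a test set before any training or tuning.
idx = list(range(len(x)))
rng.shuffle(idx)
test, train = idx[:50], idx[50:]
train_x = [x[i] for i in train]
train_y = [y[i] for i in train]

# 2. Choose lambda (number of neighbors) by 5-fold CV on the training set only.
folds = k_fold_indices(len(train_x), 5, rng)
lambdas = [1, 5, 15, 30]
best = min(lambdas, key=lambda lam: cv_mse(train_x, train_y, lam, folds))

# 3. Fit on the whole training set with that lambda; evaluate once on the test set.
test_mse = sum((knn_predict(train_x, train_y, x[i], best) - y[i]) ** 2 for i in test) / len(test)
print(f"selected lambda = {best}, test-set MSE = {test_mse:.3f}")


# 4. Refit on the entire dataset, keeping the chosen lambda fixed.
def final_model(x0):
    return knn_predict(x, y, x0, best)
```

For k-nearest neighbors, "refitting" simply means using all rows as neighbors; for other algorithms, step 4 would retrain the model with the same fixed parameters.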

[Figure: the model is refit on the entire dataset with the optimized parameters and then applied to independent datasets.]

Now how do we pick the cross validation K? Large values of K are preferable because the training data better imitate the original dataset. However, larger values of K make computation much slower: for example, 100-fold cross validation will be 10 times slower than 10-fold cross validation. For this reason, the choices of K=5 and K=10 are popular.
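A quick calculation makes the trade-off concrete. With K folds, each fit uses a fraction (K-1)/K of the training data, while the number of model fits is K times the number of candidate \lambda values; the grid of 10 candidates below is a hypothetical example.

```python
n_lambdas = 10  # hypothetical number of candidate parameter sets

summary = {}
for K in (5, 10, 100):
    train_fraction = (K - 1) / K  # share of the data used in each fit
    n_fits = K * n_lambdas  # total number of models trained
    summary[K] = (train_fraction, n_fits)
    print(f"K = {K:3d}: each fit uses {train_fraction:.0%} of the data, {n_fits} fits")
```

This prints 80%, 90%, and 99% of the data with 50, 100, and 1000 fits, respectively: going from K = 10 to K = 100 adds only 9 percentage points of training data per fit but multiplies the work by 10.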

One way we can reduce the variance of our final estimate is to take more samples. To do this, we would no longer require the training set to be partitioned into non-overlapping sets. Instead, we would just pick K sets of some size at random.
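One way to sketch this idea, sometimes called repeated random subsampling or Monte Carlo cross validation (names not used in the original text), is to draw each validation set independently with `random.sample`. The sets may then overlap, and the number of repetitions K is no longer tied to N/M. The model (y = sin(2πx) plus Gaussian noise), the hand-written k-nearest-neighbor fit, and the fixed \lambda = 15 are hypothetical choices for illustration.

```python
import math
import random


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


rng = random.Random(7)
x, y = simulate(200, rng)
N, M, K = len(x), 40, 25  # K may now exceed N / M = 5
lam = 15  # fixed number of neighbors

estimates = []
for _ in range(K):
    validate = rng.sample(range(N), M)  # drawn independently; sets may overlap
    held_out = set(validate)
    fit = [i for i in range(N) if i not in held_out]
    fit_x = [x[i] for i in fit]
    fit_y = [y[i] for i in fit]
    err = sum((knn_predict(fit_x, fit_y, x[i], lam) - y[i]) ** 2 for i in validate)
    estimates.append(err / M)

print(f"average over {K} random validation sets: {sum(estimates) / K:.3f}")
```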

One popular version of this technique, at each fold, picks observations at random with replacement (which means the same observation can appear more than once). This approach has some advantages (not discussed here) and is generally referred to as the bootstrap. In fact, this is the default approach in the caret package. We describe how to implement cross validation with the caret package in the next chapter. In the next section, we include an explanation of how the bootstrap works in general.
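A minimal sketch of the bootstrap version in plain Python: each of B resamples draws N indices with replacement (via `random.choices`) to train on. The text does not say which observations are used to evaluate each fit; this sketch assumes the error is measured on the out-of-bag observations, those not drawn into the resample. The model (y = sin(2πx) plus Gaussian noise), the hand-written k-nearest-neighbor fit, and \lambda = 15 are hypothetical. Since an observation is left out of a resample with probability (1 - 1/N)^N ≈ e^{-1} ≈ 0.37, each resample contains roughly 63% of the distinct observations.

```python
import math
import random


def simulate(n, rng):
    """Draw n (x, y) pairs from a hypothetical model: y = sin(2*pi*x) + noise."""
    xs = [rng.uniform(0, 1) for _ in range(n)]
    ys = [math.sin(2 * math.pi * x) + rng.gauss(0, 0.3) for x in xs]
    return xs, ys


def knn_predict(train_x, train_y, x0, k):
    """Predict y at x0 as the average outcome of the k closest training points."""
    nearest = sorted(range(len(train_x)), key=lambda i: abs(train_x[i] - x0))[:k]
    return sum(train_y[i] for i in nearest) / k


rng = random.Random(8)
x, y = simulate(200, rng)
N, B, lam = len(x), 25, 15

estimates, unique_share = [], []
for _ in range(B):
    boot = rng.choices(range(N), k=N)  # N draws with replacement
    in_bag = set(boot)
    oob = [i for i in range(N) if i not in in_bag]  # assumed evaluation set
    fit_x = [x[i] for i in boot]
    fit_y = [y[i] for i in boot]
    err = sum((knn_predict(fit_x, fit_y, x[i], lam) - y[i]) ** 2 for i in oob)
    estimates.append(err / len(oob))
    unique_share.append(len(in_bag) / N)

print(f"bootstrap estimate: {sum(estimates) / B:.3f}")
print(f"average share of distinct observations per resample: {sum(unique_share) / B:.2f}")
```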

License


Business Analytics Copyright © by Di Shang is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License, except where otherwise noted.
