How to Train a Machine Learning Model in JASP: Classification

In this blog post, we show how to train a classification model using JASP’s newly released Machine Learning Module. The goal of a classification task is to predict a categorical target variable based on a (possibly large) set of features/predictors. For instance, based on different concentrations of proteins, a medical specialist might want to classify tissue as “benign” or “malignant”. Similarly, a postal sorting machine has to recognize postcodes by classifying handwritten digits as “0”, “1”, “2”,…, “8”, “9”. Here, we use customer churn as an example.

Customer Churn

Customer churn occurs when a company’s customers cancel or unsubscribe from the company’s service. Also known as customer attrition, customer churn is an important metric for a company, as it is much less expensive to retain existing customers than it is to acquire new customers. Predicting whether a customer is likely to leave a company is crucial, as identifying potential churners beforehand allows the company to take proactive action to make them stay. For this task, we use the publicly available and well-known Telco Customer Churn data set from Kaggle. The data consists of 7,043 cases/observations (rows in the data set). Each case contains a customer’s features (i.e., predictors), which include demographic information and phone contract details, and the target Churn with outcomes “Yes” or “No”. There are 11 observations with missing values, which are discarded in the following analysis, leaving 7,032 complete cases.

Supervised Learning

The main assumption is that there is a relationship, that is, a true function f, between the customer’s features x_{1}, x_{2}, \ldots, x_{p} and the categorical target y (i.e., Churn with outcomes “Yes” or “No”). Mathematically, this is expressed as

    \[y = f(x_{1}, x_{2}, \ldots, x_{p}) + \epsilon,\]

where \epsilon represents idiosyncratic error. If we knew the relationship f, we could feed it a customer’s features and predict whether the customer churns or not. However, we do not know how the customer’s features are related to churning; hence, f is unknown and we have to learn (i.e., estimate) it from the data.

To learn the true function f, we show the machine learning algorithm part of the data, the so-called training set, which contains both the predictors (i.e., the customer’s characteristics) and the target, i.e., the actual observed outcome “Yes” or “No” of Churn. Because the algorithm estimates f using data that contain the actual outcomes (Churn: “Yes” or “No”), we say that we are dealing with a supervised learning problem.

To test whether the learned function f generalizes well to data the algorithm has not seen before, we do not show it the so-called “holdout test data” in the learning phase. In JASP, 20% of the data is set aside for testing by default; for the 7,032 complete cases, this amounts to 0.20 × 7,032 ≈ 1,406 cases. Supervised learning is then based on the remaining 80% of the data (5,626 cases), with which the machine learning algorithm tries to find the best-fitting function f among a large collection of candidate functions. Each of the algorithms “Boosting Classification”, “K-Nearest Neighbors Classification”, “Linear Discriminant Classification”, and “Random Forest Classification” defines a different collection of candidate functions. We illustrate the ideas further with the “K-Nearest Neighbors Classification” algorithm.
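The default 80/20 holdout split can be sketched in a few lines of Python. This is a minimal illustration, assuming a simple random split of case indices; the function name `holdout_split` is hypothetical and not part of JASP, whose own splitting code may differ (for example, in how it rounds).

```python
import random


def holdout_split(n_cases, test_fraction=0.2, seed=None):
    """Randomly split the case indices 0, ..., n_cases - 1 into training and test parts.

    Hypothetical helper for illustration; not JASP's internal implementation.
    """
    rng = random.Random(seed)                 # seed=None: a different split on every run
    indices = list(range(n_cases))
    rng.shuffle(indices)
    n_test = round(test_fraction * n_cases)   # 0.20 * 7,032 = 1,406.4, rounded to 1,406
    test_idx = sorted(indices[:n_test])
    train_idx = sorted(indices[n_test:])
    return train_idx, test_idx


# 7,043 cases minus the 11 with missing values leaves 7,032 complete cases.
train_idx, test_idx = holdout_split(7_032, test_fraction=0.2, seed=1)
print(len(train_idx), len(test_idx))          # 5626 1406
```

Passing the same `seed` twice returns the same split; leaving it as `None` gives a new random split each time.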

Before running the analysis, let’s explore the data using Descriptives. Follow along using the .jasp file from the Data Library (Open – Data Library – 11. Machine Learning – Telco Customer Churn).
 

Exploring the Data

The target variable is Churn and represents whether a customer left the company (“Yes”) or not (“No”). A quick look at the distribution of Churn using the Descriptives analysis tells us that our data set contains more non-churning than churning customers. The left panel of Figure 1 shows the distribution of churning in the full data set.

To ensure that we compute reliable performance statistics for our model later on, we created a balanced test set consisting of an equal number of non-churning and churning customers. To do this, we added a custom-made indicator variable called testIndicator that flags 20% of the data, with an equal number of churning and non-churning customers. The middle and right panels of Figure 1 show the distribution of churning in the flagged test data and in the remainder of the data, respectively.
 


Figure 1: Distribution of the target variable Churn in the full data (left), in the holdout test data (middle) and training and validation data set (right).
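The post does not say how testIndicator was constructed. One possible way, sketched below under the assumption of a binary target and simple random sampling within each class, is to draw the same number of churners and non-churners; `make_balanced_test_indicator` is a hypothetical helper, not a JASP function.

```python
import random


def make_balanced_test_indicator(labels, test_fraction=0.2, seed=None):
    """Return a list of 0/1 flags marking a class-balanced test set.

    Hypothetical helper: assumes exactly two classes, each with at least
    n_test // 2 cases (random.sample raises ValueError otherwise).
    If n_test is odd, one case fewer than n_test is flagged.
    """
    classes = sorted(set(labels))
    if len(classes) != 2:
        raise ValueError("expected exactly two classes")
    rng = random.Random(seed)
    n_per_class = round(test_fraction * len(labels)) // 2
    indicator = [0] * len(labels)
    for cls in classes:
        members = [i for i, y in enumerate(labels) if y == cls]
        for i in rng.sample(members, n_per_class):   # sample without replacement
            indicator[i] = 1
    return indicator


# Toy target with 30 churners and 70 non-churners.
labels = ["Yes"] * 30 + ["No"] * 70
test_indicator = make_balanced_test_indicator(labels, test_fraction=0.2, seed=1)
print(sum(test_indicator))                                           # 20
print(sum(f for f, y in zip(test_indicator, labels) if y == "Yes"))  # 10
```

For 7,032 complete cases, this scheme would flag 1,406 test cases, 703 per class.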

We prepare the data by selecting which variables we will use as predictors for the model. For the problem at hand, all variables except for the customer identification number (customerID) and the testIndicator can provide information about the target variable Churn.

Building a First Model – K-Nearest Neighbors

To simplify matters, we build an initial model with K=3 nearest neighbors, before we consider the general setting with unknown K. We elaborate on the algorithm based on the output.
 

To train the algorithm in JASP, select the Machine Learning menu, followed by K-Nearest Neighbors Classification. Open the “Data Split Preferences” section, select “Test set indicator”, and allocate the variable testIndicator using the dropdown menu, so that the algorithm retains 20% of the data, consisting of an equal number of churners and non-churners, for testing. To tell the algorithm to use K=3 neighbors, open the “Training Parameters” section, go to “Number of Nearest Neighbors”, select “Fixed”, and enter 3. Now go to the top, select all variables except Churn, customerID, and testIndicator, and press the arrow button next to the “Predictors” box. Lastly, add Churn to the “Target” box. Once the target and predictor boxes are filled, JASP immediately starts computing, and the results can be found in Table 1.
 

Table 1.

The table shows that 5,626 observations (80% of 7,032) were used to train our K=3-nearest neighbor model, and that the 1,406 observations flagged by the test set indicator were used to estimate the prediction error. The data split is also visualized by the bar shown in Figure 2.
 


Figure 2.

The output shows that our K=3-nearest neighbor model achieved a test set classification accuracy of 0.648, meaning that we can predict 64.8% of our holdout test data correctly.

The confusion matrix provides further insight into the prediction accuracy on the test set and shows how the included observations were predicted by the model. Table 2 shows that the algorithm correctly predicted 611 “No”s and 300 “Yes”s, resulting in a test accuracy of 911 / 1406 = 0.648.


Table 2.

In general, the training set allows us to make hard predictions for a new case of the form “Yes, this customer is going to churn” or “No, this customer is not going to churn”. The test set allows us to quantify the uncertainty with which the prediction is made. For a new case (in neither the training nor the test set), we can now, based on the confusion matrix, make more refined predictions of the form

    “Yes, this customer is going to churn with 76.5% chance”,

since 300/(300+92) = 0.765, and

    “No, this customer is not going to churn with 60.3% chance”,

since 611/(611+403) = 0.603. Note that these percentages depend on the proportion of churners in the test set, which is relatively higher than in the training set (compare the middle and right panels of Figure 1), so they should be interpreted with some caution.
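The arithmetic behind the accuracy and the two “chance” statements can be checked directly from the four cells of the confusion matrix. The counts below are those quoted in the text for Table 2 (the 92 follows from 300/(300+92)). In classification terminology, the two conditional proportions are the precision (positive predictive value) for “Yes” and the negative predictive value for “No”.

```python
# Confusion-matrix counts for the K=3 model on the balanced test set (Table 2).
tp = 300   # predicted "Yes", observed "Yes"
fp = 92    # predicted "Yes", observed "No"
tn = 611   # predicted "No",  observed "No"
fn = 403   # predicted "No",  observed "Yes"

n_test = tp + fp + tn + fn                 # 1,406 test cases
accuracy = (tp + tn) / n_test              # 911 / 1406
p_churn_given_yes = tp / (tp + fp)         # share of "Yes" predictions that were right
p_stay_given_no = tn / (tn + fn)           # share of "No" predictions that were right

print(f"{accuracy:.3f}")            # 0.648
print(f"{p_churn_given_yes:.3f}")   # 0.765
print(f"{p_stay_given_no:.3f}")     # 0.603
```

The observed totals, 300 + 403 = 703 churners and 611 + 92 = 703 non-churners, confirm that the test set is balanced.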

Explaining the Algorithm

The K-Nearest Neighbors algorithm takes the features of a specific customer from our holdout test set, say, Tammy, and considers the K customers whose features/predictors are closest to Tammy’s. Suppose K=3 and that Beth (Churn: “Yes”), Jerry (Churn: “No”) and Summer (Churn: “No”) from the training set are the three neighbors closest to Tammy feature-wise; the algorithm then takes the majority vote to predict whether Tammy churned or not. As it is two “No”s against one “Yes”, the algorithm predicts “No”. Hence, Tammy’s case goes into the “No” column of the confusion matrix. Note that this prediction does not use Tammy’s response: it only depends on the responses of the training cases/observations whose features/predictors are close to Tammy’s.

To evaluate the performance of this prediction, we now compare Tammy’s actual response to the predicted “No”. If the observed response was also “No”, then Tammy’s case contributes to the top-left cell of the confusion matrix and this specific prediction is done correctly. On the other hand, if Tammy’s real response was “Yes”, we made an error and this adds to the bottom-left cell of the confusion matrix.
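The prediction step for a single case can be sketched as follows. This is a minimal Python illustration, assuming Euclidean distance and a list of `(features, label)` pairs; the helper names are hypothetical, and the coordinates for Beth, Jerry, Summer, and Mortimer are made up for the example rather than read off the post’s figures.

```python
import math


def majority_vote(neighbor_labels):
    """Most frequent label; on a tie, the label that appears first wins."""
    counts = {}
    for label in neighbor_labels:              # labels ordered from nearest to farthest
        counts[label] = counts.get(label, 0) + 1
    return max(counts, key=counts.get)         # max() returns the first maximal key


def knn_predict(query, training_set, k):
    """Predict the label of `query` from its k nearest (Euclidean) training cases."""
    ranked = sorted(training_set, key=lambda case: math.dist(query, case[0]))
    return majority_vote([label for _, label in ranked[:k]])


# Hypothetical two-feature coordinates; Tammy's own label is never used.
training_set = [
    ((0.5, 0.25), "Yes"),   # Beth
    ((2.0, 0.0), "No"),     # Jerry
    ((1.6, 1.6), "No"),     # Summer
    ((2.4, 0.2), "Yes"),    # Mortimer
]
tammy = (0.0, 0.0)
print(knn_predict(tammy, training_set, k=3))   # No
```

Comparing this prediction with Tammy’s observed label then adds one count to the corresponding cell of the confusion matrix.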

The Distance Parameter

All model parameters for the analysis can be found in the “Training Parameters” section. The distance parameter defines what is meant by nearest, and we can choose between the Euclidean distance and the Manhattan distance. To clarify the role of distance, suppose we only have two predictors/features, “Total charges” and “Age”, which are depicted on the horizontal and vertical axes, respectively, in Figure 3.


Figure 3.

 
Figure 3 also depicts Tammy’s features with a “T”, as well as Beth’s “B”, Jerry’s “J”, Summer’s “S”, and Mortimer’s “M”. To find Tammy’s K=3 nearest neighbors in Euclidean distance, we draw circles with increasingly larger radii around Tammy’s features, as shown in Figure 4.


Figure 4.

 
The idea is similar to detecting objects using sonar and shows that Beth is closest to Tammy feature-wise, followed by Jerry and Summer. Note that the K=3 nearest neighbors are found after five iterations. For completeness, the Euclidean distance is defined as

    \[\sqrt{|x_{1}|^{2} + |x_{2}|^{2}} = (|x_{1}|^{2} + |x_{2}|^{2})^{1/2},\]

where x_{1} represents the difference from Tammy’s measurement of “Total charges” and x_{2} the difference from Tammy’s “Age”.

Now, using the fact that blue in the figures represents a churned customer (“Yes”) and red a non-churned customer (“No”) in our training set, the algorithm predicts “No” for Tammy.

On the other hand, to find Tammy’s K=3 nearest neighbors in the Manhattan distance, we replace the circles by diamonds, as shown in Figure 5.


Figure 5.

 
In terms of the Manhattan distance, Tammy’s K=3 nearest neighbors are Beth, Jerry, and Mortimer, and it took six iterations to identify them. Note that these K=3 neighbors represent two churned customers and one non-churned customer. Hence, based on the Manhattan distance, the algorithm would predict “Yes” for Tammy. For completeness, the Manhattan distance is defined as

    \[|x_{1}|^{1} + |x_{2}|^{1} = (|x_{1}|^{1} + |x_{2}|^{1})^{1/1},\]

and note that this is similar to the Euclidean distance, but with the two in the exponent replaced by a one.
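Both distances are special cases of the Minkowski distance with exponent p (p=2 for Euclidean, p=1 for Manhattan). The sketch below uses hypothetical coordinates chosen so that, as in Figures 4 and 5, the two distances select different neighbor sets; the names, numbers, and helper functions are illustrative assumptions, not values or code from JASP.

```python
def minkowski(a, b, p):
    """(|x_1|^p + |x_2|^p + ...)^(1/p) with x_i = a_i - b_i; p=2 Euclidean, p=1 Manhattan."""
    return sum(abs(ai - bi) ** p for ai, bi in zip(a, b)) ** (1 / p)


# Hypothetical (Total charges, Age) coordinates on a common scale, plus Churn labels.
# In practice the predictors would usually be standardized first.
cases = {
    "Beth": ((0.5, 0.25), "Yes"),
    "Jerry": ((2.0, 0.0), "No"),
    "Summer": ((1.6, 1.6), "No"),
    "Mortimer": ((2.4, 0.2), "Yes"),
}
tammy = (0.0, 0.0)


def nearest(query, k, p):
    """Names of the k cases closest to `query` under the Minkowski-p distance."""
    return sorted(cases, key=lambda name: minkowski(query, cases[name][0], p))[:k]


for p, metric in [(2, "Euclidean"), (1, "Manhattan")]:
    neighbors = nearest(tammy, k=3, p=p)
    votes = [cases[name][1] for name in neighbors]
    prediction = max(("Yes", "No"), key=votes.count)   # k=3 with two classes: no ties
    print(f"{metric}: {neighbors} -> {prediction}")
# Euclidean: ['Beth', 'Jerry', 'Summer'] -> No
# Manhattan: ['Beth', 'Jerry', 'Mortimer'] -> Yes
```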

The Weights Parameter

In the “Training Parameters” section you can also find a “Weights” parameter. By weighting, we deem the cases/observations closest to Tammy in feature space more important than neighbors that are far away. Consider Figure 3 again and note that Beth is much closer to Tammy than Jerry and Summer are. Hence, a weighting scheme can make Beth’s single “Yes” much more important than Jerry’s and Summer’s two “No”s, resulting in the algorithm predicting “Yes” instead of “No”. For a definition of the different weighting schemes, see the help file, and for a more detailed exposition we refer to Hechenbichler & Schliep (2004).
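To see how weighting can flip a prediction, the sketch below uses inverse-distance weights, 1/distance. This is only one of several possible schemes and is an assumption here, not necessarily one of the options JASP offers; the help file and Hechenbichler & Schliep (2004) describe the actual schemes. The coordinates and the helper `weighted_knn_predict` are hypothetical.

```python
import math


def weighted_knn_predict(query, training_set, k, eps=1e-12):
    """k-NN prediction with inverse-distance weights (an illustrative choice).

    Returns the predicted label and the summed weight per label.
    """
    ranked = sorted(training_set, key=lambda case: math.dist(query, case[0]))[:k]
    scores = {}
    for features, label in ranked:
        weight = 1.0 / (math.dist(query, features) + eps)   # eps guards against d = 0
        scores[label] = scores.get(label, 0.0) + weight
    return max(scores, key=scores.get), scores


# Hypothetical coordinates: Beth is much closer to Tammy than Jerry and Summer.
training_set = [
    ((0.5, 0.25), "Yes"),   # Beth
    ((2.0, 0.0), "No"),     # Jerry
    ((1.6, 1.6), "No"),     # Summer
]
tammy = (0.0, 0.0)
label, scores = weighted_knn_predict(tammy, training_set, k=3)
print(label, {name: round(s, 2) for name, s in scores.items()})
# Yes {'Yes': 1.79, 'No': 0.94}
```

An unweighted vote over the same three neighbors would predict “No” (two votes against one).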

The Number of Nearest Neighbors

Let us now focus on the parameter K, the number of nearest neighbors. In general, we do not know how many nearest neighbors contain vital information to correctly predict a new case/observation. Recall that with K=3 we correctly predict 64.8% of the test cases, as shown in Table 1. By increasing K we use more information, and one would expect to do better. This is true to a certain extent. To see this in action, we change the number of nearest neighbors from three to five. Table 3 shows the result.
 


Table 3.

The K=5-nearest neighbor model has a slightly higher test accuracy than the K=3-nearest neighbor model. Performance increases further with a K=16-nearest neighbor model
 


Table 4.

and this trend suggests that we should simply keep increasing K. However, the results of a K=48-nearest neighbor model show that the test accuracy now decreases
 


Table 5.

Manually increasing the number of neighbors to find the K with the highest test accuracy is tedious. Furthermore, once we use the test data to select K, we can no longer use the test accuracy as an honest measure of the uncertainty of our predictions, since the test data were then used to learn the model. To correct for this, we should split the learning data into two parts: a “pure” training set, and a so-called validation set that is used to select K.

To have the algorithm automatically find the model with the best K in JASP, go to the “Training Parameters” section, “Number of Nearest Neighbors”, and choose “Optimized”. By default, the maximum number of neighbors considered is 50. Once this is selected, note the appearance of the “Training and Validation Data” option in the “Data Split Preferences” section, and that the visualization in the output changes to
 


Figure 6.

With a maximum number of neighbors of 50, the algorithm is now run 50 times: for K=1 it uses the training set, visualized in blue, and evaluates its predictions on the validation set, visualized in orange. This is then also done for K=2, and so on, until K=50. This results in
 


Table 6.

To first clarify the numbers of samples “n(Train)”, “n(Validation)”, and “n(Test)”, note that with 20%, “n(Test)=1,406”, held out, there are still 7,032 – 1,406 = 5,626 cases/observations left for learning. As shown under “Training and Validation Data”, a further 20% of these 5,626 cases (1,126 cases) is used for validation, that is, to select K.

The algorithm found K=21, as it is the model with the highest validation accuracy, namely, 82.9%. Note that while a model with K=21 performs well on the validation data, it doesn’t do that much better for our specific test set. The classification accuracy plot in Figure 7 visualizes the performance of the 50 runs of the algorithm (based on the training and validation set) for the models with K=1, K=2, \ldots, K=50.
 


Figure 7.

The red dot represents the number of nearest neighbors with the highest validation classification accuracy. The number of neighbors K controls how flexible the model is: a small K gives a flexible model, a large K a simple one. Setting K incorrectly might result in a model that captures idiosyncratic noise instead of structure (K too small) or one that misses the structure (K too large). Highly flexible models tend to yield predictions that vary too much, whereas overly simple models tend to be biased, i.e., they do not capture the true relationship between the predictors and the target. Hence, there is a trade-off between flexibility and simplicity, and the model that generalizes best strikes a balance between the two.
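The “Optimized” search over K with a single validation set can be sketched as follows. Since the Telco data are not reproduced here, the example generates hypothetical two-feature toy data; the helper names are illustrative, and JASP’s own procedure may differ in details such as tie-breaking or how the validation cases are drawn.

```python
import math
import random


def make_toy_data(n, rng):
    """Hypothetical two-class data: "Yes" becomes more likely as x + y grows."""
    data = []
    for _ in range(n):
        x, y = rng.random(), rng.random()
        p_yes = 1 / (1 + math.exp(-6 * (x + y - 1)))        # noisy class boundary
        data.append(((x, y), "Yes" if rng.random() < p_yes else "No"))
    return data


def majority_vote(labels):
    """Most frequent label; ties go to the label of the closer neighbor."""
    counts = {}
    for label in labels:                                     # ordered nearest first
        counts[label] = counts.get(label, 0) + 1
    return max(counts, key=counts.get)


def validation_accuracies(train, valid, max_k):
    """Validation accuracy of k-NN (Euclidean distance) for k = 1, ..., max_k."""
    # Rank the training labels by distance once per validation case, reuse for every k.
    ranked = [
        [label for _, label in sorted(train, key=lambda c: math.dist(feat, c[0]))]
        for feat, _ in valid
    ]
    accuracies = []
    for k in range(1, max_k + 1):
        correct = sum(majority_vote(r[:k]) == truth for r, (_, truth) in zip(ranked, valid))
        accuracies.append(correct / len(valid))
    return accuracies


rng = random.Random(1)                         # fixed seed: same split on every run
learning = make_toy_data(500, rng)             # stands in for the learning data
n_valid = round(0.2 * len(learning))           # a further 20% for validation
valid, train = learning[:n_valid], learning[n_valid:]

max_k = 50
accs = validation_accuracies(train, valid, max_k)
best_k = max(range(1, max_k + 1), key=lambda k: accs[k - 1])   # smallest K among ties
print(f"best K = {best_k}, validation accuracy = {accs[best_k - 1]:.3f}")
```

Because only the validation accuracy is used to choose K, a separate held-out test set stays untouched and can still give an honest estimate of the selected model’s performance.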

Cross-Validation

Recall that machine learning algorithms are data-driven. This implies that (1) the results are only as good as the data the algorithm is given, and that (2) the results differ each time a different training, validation, or test set is chosen. The latter occurs at each run of the algorithm, because each time a different randomly selected part of the data is used as the validation set. To have the algorithm choose the same “randomly” selected validation set, we fixed the so-called seed at 1. This option is found in the “Training Parameters” section. By default, we do not fix a seed, because machine learning algorithms are intrinsically random, and we do not want to create the false impression that they are deterministic.

A better way to make the algorithm less sensitive to the particular data split is to cross-validate. This computationally intensive option is found in the “Data Split Preferences” section and can be activated by checking K-fold with, say, 5 folds (this K does not refer to the number of nearest neighbors). The algorithm then first randomly partitions the training data into 5 parts, and then runs the algorithm with each of these parts in turn as the validation set to learn the number of neighbors K. Hence, this requires 5 times as many computations as running it with a single validation set, which already required 50 runs since the maximum K was 50. These 250 runs of the algorithm lead to
 


Table 7.

As each case is sometimes used as a validation and sometimes as a training observation, the divide between training and validation disappears, which is why the data split is now visualized as
 


Figure 8.

For the data at hand, the largest possible number of folds is 5,626, the number of cases (80% of the data) used to learn the model. K-fold cross-validation with 5,626 folds is the same as leave-one-out cross-validation, which can also be selected directly in the “Data Split Preferences” section. Here, the cross-validated results do not show much improvement.
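The k-fold procedure can be sketched in the same style, again on hypothetical toy data and with illustrative helper names rather than JASP code. Each case is held out exactly once, and setting `n_folds` equal to the number of cases gives leave-one-out cross-validation.

```python
import math
import random


def make_toy_data(n, rng):
    """Hypothetical two-class data: "Yes" becomes more likely as x + y grows."""
    data = []
    for _ in range(n):
        x, y = rng.random(), rng.random()
        p_yes = 1 / (1 + math.exp(-6 * (x + y - 1)))
        data.append(((x, y), "Yes" if rng.random() < p_yes else "No"))
    return data


def majority_vote(labels):
    """Most frequent label; ties go to the label of the closer neighbor."""
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return max(counts, key=counts.get)


def cv_accuracies(data, max_k, n_folds, rng):
    """Cross-validated k-NN accuracy for k = 1, ..., max_k.

    Every case is held out exactly once, so accuracy is pooled over all folds.
    n_folds = len(data) gives leave-one-out cross-validation.
    """
    indices = list(range(len(data)))
    rng.shuffle(indices)                                    # random partition into folds
    folds = [indices[f::n_folds] for f in range(n_folds)]
    correct = [0] * max_k
    for fold in folds:
        held_out = set(fold)
        train = [data[i] for i in indices if i not in held_out]
        for i in fold:
            feat, truth = data[i]
            ranked = [lab for _, lab in sorted(train, key=lambda c: math.dist(feat, c[0]))]
            for k in range(1, max_k + 1):
                correct[k - 1] += majority_vote(ranked[:k]) == truth
    return [c / len(data) for c in correct]


rng = random.Random(1)
learning = make_toy_data(500, rng)
accs = cv_accuracies(learning, max_k=50, n_folds=5, rng=rng)
best_k = max(range(1, 51), key=lambda k: accs[k - 1])
print(f"5-fold CV: best K = {best_k}, accuracy = {accs[best_k - 1]:.3f}")
```

With 5 folds and K up to 50, the inner loop evaluates 5 × 50 = 250 fold-and-K combinations, mirroring the count in the text.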

Making Predictions – Looking Forward

Going back to the K-nearest neighbors algorithm with a fixed training and validation set, we obtained a model with K=21 nearest neighbors and a test accuracy of 68.1%. Further options to explore the output come in the form of additional tables. For example, performance metrics such as precision, recall, F1-score, and the area under the ROC curve (AUC) can be viewed in the evaluation metrics table. Under “Plots”, we can, amongst other things, request ROC curves for each outcome of the target, in this case two: “Yes” and “No”.
 


Figure 9.
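The metrics in the evaluation table follow standard definitions, sketched below in Python for the “Yes” class. The counts are those quoted for the K=3 model in Table 2, since the corresponding table for the K=21 model is not reproduced in this post. The AUC function uses the rank-based (Mann-Whitney) formulation; JASP’s exact computations, including how per-class values are reported, are not shown here and may differ.

```python
def precision_recall_f1(tp, fp, fn):
    """Precision, recall, and F1-score for the positive class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def roc_auc(scores, labels):
    """Area under the ROC curve via the Mann-Whitney formulation.

    The AUC is the probability that a randomly chosen positive case (label 1)
    receives a higher score than a randomly chosen negative case (label 0);
    ties count one half. O(n_pos * n_neg), which is fine for a sketch.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# "Yes" as the positive class, using the K=3 counts from Table 2.
precision, recall, f1 = precision_recall_f1(tp=300, fp=92, fn=403)
print(f"precision={precision:.3f} recall={recall:.3f} F1={f1:.3f}")

# For k-NN, one possible score is the fraction of the K neighbors voting "Yes".
print(roc_auc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))   # 0.75
```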

We end this post with a model that strikes a balance between flexibility and generalizability, and we can use this model to predict whether our future telecom customers are on the verge of churning.

In a future release, we will include the option to save the trained model and subsequently apply it to a new data set. For now, we hope we can make you happy with the currently implemented classification analyses: Boosting, K-Nearest Neighbors, Linear Discriminant Analysis (LDA), and Random Forest.

This post is part of a three-part series on machine learning; be sure to also check out the upcoming posts on regression and clustering!

References

Hechenbichler, K., & Schliep, K. P. (2004). Weighted k-Nearest-Neighbor Techniques and Ordinal Classification. Discussion Paper 399, SFB 386, Ludwig-Maximilians University Munich.

James, G., Witten, D., Hastie, T., & Tibshirani, R. (2013). An Introduction to Statistical Learning. Springer Texts in Statistics.


 


About the authors

Alexander Ly

Alexander Ly is the CTO of JASP and responsible for guiding JASP’s scientific and technological strategy as well as the development of some Bayesian tests.

Koen Derks

Koen Derks is a PhD candidate at Nyenrode Business University and at the Psychological Methods group at the University of Amsterdam. At JASP, he is creating JfA, an add-on module for Auditing.