ConfusionTableR

This package allows for the rapid transformation of confusion matrix objects from the caret package into data frame objects, something that is otherwise fiddly because the native objects are stored as lists.

Why is this useful?

This is useful because it turns the list items into a single transposed row-and-column data frame. I had the idea while working with a number of machine learning models: I wanted to store the results in database tables, so I needed a way to have one row per model run. This is something that is not implemented in the excellent caret package created by Max Kuhn [https://cran.r-project.org/web/packages/caret/index.html].

Preparing the ML model to evaluate

The following approach shows how a single function call can flatten all the results of a caret confusion matrix into one row. The first example works through a multiclass classification problem:

Example:

library(ggplot2)
library(caret)
library(caretEnsemble)
library(scales)
library(mltools)
library(mlbench)

# Load in the iris data set for this problem 
data(iris)
df <- iris
# View the class distribution; as this is a multiclass problem, we can use the multiclass version of the data frame builder
table(iris$Species)
#> 
#>     setosa versicolor  virginica 
#>         50         50         50
ggplot(data = iris,
       aes(x=Species)) + geom_bar(aes(fill = Species), show.legend = FALSE) + theme_minimal()

# We can see we have a balanced dataset. I will now create a simple test and train split on the data
train_split_idx <- caret::createDataPartition(df$Species, p = 0.75, list = FALSE)
# Here we define a split index and we are now going to use a multiclass ML model to fit the data
data_TRAIN <- df[train_split_idx, ]
data_TEST <- df[-train_split_idx, ]
str(data_TRAIN)
#> 'data.frame':    114 obs. of  5 variables:
#>  $ Sepal.Length: num  5.1 4.6 5 4.6 5 4.4 5.4 4.8 4.8 4.3 ...
#>  $ Sepal.Width : num  3.5 3.1 3.6 3.4 3.4 2.9 3.7 3.4 3 3 ...
#>  $ Petal.Length: num  1.4 1.5 1.4 1.4 1.5 1.4 1.5 1.6 1.4 1.1 ...
#>  $ Petal.Width : num  0.2 0.2 0.2 0.3 0.2 0.2 0.2 0.2 0.1 0.1 ...
#>  $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

This creates a 75% training set for the ML model; the remaining 25% is held back as validation data to test the model.
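
Because caret::createDataPartition samples within each class, the split preserves the balanced class distribution. A quick sanity check (not part of the package) confirms this:

# Check the split proportion and the class balance in each partition
nrow(data_TRAIN) / nrow(df) # 114/150 = 0.76
prop.table(table(data_TRAIN$Species))
prop.table(table(data_TEST$Species))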

rf_model <- caret::train(Species ~ .,
                         data = data_TRAIN,
                         method = "rf",
                         metric = "Accuracy")

rf_model
#> Random Forest 
#> 
#> 114 samples
#>   4 predictor
#>   3 classes: 'setosa', 'versicolor', 'virginica' 
#> 
#> No pre-processing
#> Resampling: Bootstrapped (25 reps) 
#> Summary of sample sizes: 114, 114, 114, 114, 114, 114, ... 
#> Resampling results across tuning parameters:
#> 
#>   mtry  Accuracy   Kappa    
#>   2     0.9468768  0.9198443
#>   3     0.9481294  0.9217557
#>   4     0.9481852  0.9218222
#> 
#> Accuracy was used to select the optimal model using the largest value.
#> The final value used for the model was mtry = 4.

The model is relatively accurate. This is not a lesson in machine learning; we now know how well the model performs on resamples of the training data, and we need to validate it on the held-out data with a confusion matrix. The Random Forest has been trained on more than two classes, so this moves from a binary model to a multiclass model. The functions contained in the package work with both binary and multiclass methods.
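
You can confirm the number of classes directly from the outcome factor:

nlevels(data_TRAIN$Species)
#> [1] 3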

Using the native confusion matrix in caret

caret's native confusion matrix is excellent, but it is stored as a series of list items. To compare model performance over time, and to catch underlying feature slippage and regressions in performance, those items need to be brought together into one structure. This is where my solution comes in.

# Make a prediction on the fitted model with the test data
rf_class <- predict(rf_model, newdata = data_TEST, type = "raw") 

# Create a confusion matrix object
cm <- confusionMatrix(rf_class,
                      data_TEST[,names(data_TEST) %in% c("Species")])

print(cm) 
#> Confusion Matrix and Statistics
#> 
#>             Reference
#> Prediction   setosa versicolor virginica
#>   setosa         12          0         0
#>   versicolor      0         12         0
#>   virginica       0          0        12
#> 
#> Overall Statistics
#>                                      
#>                Accuracy : 1          
#>                  95% CI : (0.9026, 1)
#>     No Information Rate : 0.3333     
#>     P-Value [Acc > NIR] : < 2.2e-16  
#>                                      
#>                   Kappa : 1          
#>                                      
#>  Mcnemar's Test P-Value : NA         
#> 
#> Statistics by Class:
#> 
#>                      Class: setosa Class: versicolor Class: virginica
#> Sensitivity                 1.0000            1.0000           1.0000
#> Specificity                 1.0000            1.0000           1.0000
#> Pos Pred Value              1.0000            1.0000           1.0000
#> Neg Pred Value              1.0000            1.0000           1.0000
#> Prevalence                  0.3333            0.3333           0.3333
#> Detection Rate              0.3333            0.3333           0.3333
#> Detection Prevalence        0.3333            0.3333           0.3333
#> Balanced Accuracy           1.0000            1.0000           1.0000
typeof(cm)
#> [1] "list"

The outputs of the matrix are really useful, however I want to combine all this information into one row of a data frame for storage in a data table and import into a database.

Using ConfusionTableR to collapse this data into a data frame

The package has two functions for dealing with these types of problems. First, I will show the multiclass version and how it can be implemented:

Example

# Implementing function to collapse data

library(ConfusionTableR)
mc_df <- ConfusionTableR::multi_class_cm(cm)
print(mc_df)
#>   setosa : setosa setosa : versicolor setosa : virginica versicolor : setosa
#> 1              12                   0                  0                   0
#>   versicolor : versicolor versicolor : virginica virginica : setosa
#> 1                      12                      0                  0
#>   virginica : versicolor virginica : virginica Accuracy Kappa AccuracyLower
#> 1                      0                    12        1     1     0.9026062
#>   AccuracyUpper AccuracyNull AccuracyPValue McnemarPValue setosa : Sensitivity
#> 1             1    0.3333333   6.662463e-18           NaN                    1
#>   setosa : Specificity setosa : Pos Pred Value setosa : Neg Pred Value
#> 1                    1                       1                       1
#>   setosa : Precision setosa : Recall setosa : F1 setosa : Prevalence
#> 1                  1               1           1           0.3333333
#>   setosa : Detection Rate setosa : Detection Prevalence
#> 1               0.3333333                     0.3333333
#>   setosa : Balanced Accuracy versicolor : Sensitivity versicolor : Specificity
#> 1                          1                        1                        1
#>   versicolor : Pos Pred Value versicolor : Neg Pred Value
#> 1                           1                           1
#>   versicolor : Precision versicolor : Recall versicolor : F1
#> 1                      1                   1               1
#>   versicolor : Prevalence versicolor : Detection Rate
#> 1               0.3333333                   0.3333333
#>   versicolor : Detection Prevalence versicolor : Balanced Accuracy
#> 1                         0.3333333                              1
#>   virginica : Sensitivity virginica : Specificity virginica : Pos Pred Value
#> 1                       1                       1                          1
#>   virginica : Neg Pred Value virginica : Precision virginica : Recall
#> 1                          1                     1                  1
#>   virginica : F1 virginica : Prevalence virginica : Detection Rate
#> 1              1              0.3333333                  0.3333333
#>   virginica : Detection Prevalence virginica : Balanced Accuracy
#> 1                        0.3333333                             1
names(mc_df)
#>  [1] "setosa : setosa"                   "setosa : versicolor"              
#>  [3] "setosa : virginica"                "versicolor : setosa"              
#>  [5] "versicolor : versicolor"           "versicolor : virginica"           
#>  [7] "virginica : setosa"                "virginica : versicolor"           
#>  [9] "virginica : virginica"             "Accuracy"                         
#> [11] "Kappa"                             "AccuracyLower"                    
#> [13] "AccuracyUpper"                     "AccuracyNull"                     
#> [15] "AccuracyPValue"                    "McnemarPValue"                    
#> [17] "setosa : Sensitivity"              "setosa : Specificity"             
#> [19] "setosa : Pos Pred Value"           "setosa : Neg Pred Value"          
#> [21] "setosa : Precision"                "setosa : Recall"                  
#> [23] "setosa : F1"                       "setosa : Prevalence"              
#> [25] "setosa : Detection Rate"           "setosa : Detection Prevalence"    
#> [27] "setosa : Balanced Accuracy"        "versicolor : Sensitivity"         
#> [29] "versicolor : Specificity"          "versicolor : Pos Pred Value"      
#> [31] "versicolor : Neg Pred Value"       "versicolor : Precision"           
#> [33] "versicolor : Recall"               "versicolor : F1"                  
#> [35] "versicolor : Prevalence"           "versicolor : Detection Rate"      
#> [37] "versicolor : Detection Prevalence" "versicolor : Balanced Accuracy"   
#> [39] "virginica : Sensitivity"           "virginica : Specificity"          
#> [41] "virginica : Pos Pred Value"        "virginica : Neg Pred Value"       
#> [43] "virginica : Precision"             "virginica : Recall"               
#> [45] "virginica : F1"                    "virginica : Prevalence"           
#> [47] "virginica : Detection Rate"        "virginica : Detection Prevalence" 
#> [49] "virginica : Balanced Accuracy"
class(mc_df) # This has now been converted to a data frame
#> [1] "data.frame"

This data frame can now be used to store and analyse these records over time, i.e. monitoring whether the machine learning statistics degrade across different training runs.
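
As a sketch of that monitoring workflow (the model_id values and the second run are hypothetical), successive flattened runs can be stacked and compared:

# Hypothetical sketch: tag each flattened run, then stack the runs
run_1 <- cbind(model_id = "rf_iris_run_1", mc_df)
run_2 <- cbind(model_id = "rf_iris_run_2", mc_df) # a later retrain would go here
cm_history <- rbind(run_1, run_2)
# Watch for degradation in the key metrics across runs
cm_history[, c("model_id", "Accuracy", "Kappa")]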

Using ConfusionTableR to collapse binary confusion matrix outputs

Here we will use the breast cancer dataset from the mlbench package. The following steps show how to prepare the data:

Example

Preparing the data and fitting the model

# Load the breast cancer data from mlbench
data("BreastCancer", package = "mlbench")
# Keep only the complete cases and drop the Id column
breast <- BreastCancer[complete.cases(BreastCancer), ]
breast <- breast[, -1]
# This gives a binary classification problem: benign vs malignant
breast$Class <- factor(breast$Class)
# Convert the nine ordinal predictors to numeric
for(i in 1:9) {
  breast[, i] <- as.numeric(as.character(breast[, i]))
}

# Train a ML model to fit to it
# Split the data
train_split_idx <- caret::createDataPartition(breast$Class, p = 0.75, list = FALSE)
# Here we define a split index and we are now going to fit a binary classification model to the data
data_TRAIN <- breast[train_split_idx, ]
data_TEST <- breast[-train_split_idx, ]
# Fit a logistic regression model
glm_model <- train(Class ~ Cl.thickness + Cell.size + Cell.shape + Marg.adhesion + Normal.nucleoli,
                   data = data_TRAIN,
                   method = "glm",
                   family = "binomial")

glm_model
#> Generalized Linear Model 
#> 
#> 513 samples
#>   5 predictor
#>   2 classes: 'benign', 'malignant' 
#> 
#> No pre-processing
#> Resampling: Bootstrapped (25 reps) 
#> Summary of sample sizes: 513, 513, 513, 513, 513, 513, ... 
#> Resampling results:
#> 
#>   Accuracy   Kappa    
#>   0.9519893  0.8927816

We now have our binary breast cancer model. Next we will build a confusion matrix for it and use the tools in ConfusionTableR to output a data frame and build a visualisation of the matrix.

Predicting and building the confusion matrix

This snippet shows how to achieve this:

glm_class <- predict(glm_model, newdata = data_TEST, type = "raw") 

# Create a confusion matrix object
cm <- confusionMatrix(glm_class,
                      data_TEST[,names(data_TEST) %in% c("Class")])

print(cm)
#> Confusion Matrix and Statistics
#> 
#>            Reference
#> Prediction  benign malignant
#>   benign       107         1
#>   malignant      4        58
#>                                           
#>                Accuracy : 0.9706          
#>                  95% CI : (0.9327, 0.9904)
#>     No Information Rate : 0.6529          
#>     P-Value [Acc > NIR] : <2e-16          
#>                                           
#>                   Kappa : 0.9359          
#>                                           
#>  Mcnemar's Test P-Value : 0.3711          
#>                                           
#>             Sensitivity : 0.9640          
#>             Specificity : 0.9831          
#>          Pos Pred Value : 0.9907          
#>          Neg Pred Value : 0.9355          
#>              Prevalence : 0.6529          
#>          Detection Rate : 0.6294          
#>    Detection Prevalence : 0.6353          
#>       Balanced Accuracy : 0.9735          
#>                                           
#>        'Positive' Class : benign          
#> 

Now this is where we will use the package to reduce the matrix to a data frame and to visualise it.

Binary Confusion Matrix Data Frame

The following example shows how this is implemented:


bin_df <- ConfusionTableR::binary_class_cm(cm)
names(bin_df)
#>  [1] "Pred_benign_Ref_benign"       "Pred_malignant_Ref_benign"   
#>  [3] "Pred_benign_Ref_malignant"    "Pred_malignant_Ref_malignant"
#>  [5] "Accuracy"                     "Kappa"                       
#>  [7] "AccuracyLower"                "AccuracyUpper"               
#>  [9] "AccuracyNull"                 "AccuracyPValue"              
#> [11] "McnemarPValue"                "Sensitivity"                 
#> [13] "Specificity"                  "Pos.Pred.Value"              
#> [15] "Neg.Pred.Value"               "Precision"                   
#> [17] "Recall"                       "F1"                          
#> [19] "Prevalence"                   "Detection.Rate"              
#> [21] "Detection.Prevalence"         "Balanced.Accuracy"           
#> [23] "cm_ts"
print(bin_df)
#>   Pred_benign_Ref_benign Pred_malignant_Ref_benign Pred_benign_Ref_malignant
#> 1                    107                         4                         1
#>   Pred_malignant_Ref_malignant  Accuracy     Kappa AccuracyLower AccuracyUpper
#> 1                           58 0.9705882 0.9358684     0.9327003     0.9903825
#>   AccuracyNull AccuracyPValue McnemarPValue Sensitivity Specificity
#> 1    0.6529412   1.692566e-24     0.3710934    0.963964   0.9830508
#>   Pos.Pred.Value Neg.Pred.Value Precision   Recall        F1 Prevalence
#> 1      0.9907407      0.9354839 0.9907407 0.963964 0.9771689  0.6529412
#>   Detection.Rate Detection.Prevalence Balanced.Accuracy               cm_ts
#> 1      0.6294118            0.6352941         0.9735074 2021-04-06 17:05:17

This is now a data.frame and can be saved as a single record to a database server to monitor confusion matrix performance over time.
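
For example, using DBI with an in-memory SQLite database (the connection and table name here are purely illustrative stand-ins for your own database server):

library(DBI)
library(RSQLite)

# Swap the in-memory connection for your server connection
con <- DBI::dbConnect(RSQLite::SQLite(), ":memory:")
DBI::dbWriteTable(con, "confusion_matrix_log", bin_df) # first run creates the table
# Subsequent runs would append a new row:
# DBI::dbWriteTable(con, "confusion_matrix_log", new_bin_df, append = TRUE)
DBI::dbReadTable(con, "confusion_matrix_log")
DBI::dbDisconnect(con)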

Visualising the confusion matrix

The last tool in the package produces a nice visual of the confusion matrix that can be used in presentations and papers to display the matrix and its associated summary statistics:


ConfusionTableR::binary_visualiseR(
  cm_input = cm, class_label1 = "Benign", class_label2 = "Malignant",
  quadrant_col1 = "#28ACB4", quadrant_col2 = "#4397D2", 
  custom_title = "Breast Cancer Confusion Matrix", text_col= "black"
)
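
Since the visual is drawn to the active graphics device, one way to keep it for a presentation or paper is to wrap the call in a graphics device (the file name is illustrative):

png("breast_cancer_cm.png", width = 800, height = 600)
ConfusionTableR::binary_visualiseR(
  cm_input = cm, class_label1 = "Benign", class_label2 = "Malignant",
  quadrant_col1 = "#28ACB4", quadrant_col2 = "#4397D2", 
  custom_title = "Breast Cancer Confusion Matrix", text_col = "black"
)
dev.off()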

These outputs can be used in combination with caret to build up an analysis of how well the model fits now and how well it will fit in the future, by tracking Cohen's Kappa and the other associated metrics.
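
For reference, Cohen's Kappa compares the observed accuracy with the agreement expected by chance. A minimal computation from the raw counts stored in the caret object reproduces the reported value:

tab <- cm$table # raw confusion matrix counts
p_o <- sum(diag(tab)) / sum(tab) # observed accuracy
p_e <- sum(rowSums(tab) * colSums(tab)) / sum(tab)^2 # chance agreement
(p_o - p_e) / (1 - p_e) # matches cm$overall["Kappa"]
#> [1] 0.9358684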

Getting variable importance

The var_impeR function creates a variable importance chart and tibble. To use this function see below:

ConfusionTableR::var_impeR(glm_model)
#> $model
#> # A tibble: 5 x 5
#>   rowid Metric          Overall  Prop generated          
#>   <int> <chr>             <dbl> <dbl> <dttm>             
#> 1     1 Cl.thickness      100   0.377 2021-04-06 17:05:17
#> 2     2 Cell.size           0   0     2021-04-06 17:05:17
#> 3     3 Cell.shape         52.6 0.198 2021-04-06 17:05:17
#> 4     4 Marg.adhesion      63.7 0.240 2021-04-06 17:05:17
#> 5     5 Normal.nucleoli    48.7 0.184 2021-04-06 17:05:17
#> 
#> $glob_var_imp_plot

This returns a tibble of the variable importance values, alongside a chart to display them.
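
To work with the two returned elements separately, capture the list:

vimp <- ConfusionTableR::var_impeR(glm_model)
vimp$model # the tibble of importance values
vimp$glob_var_imp_plot # the chart, ready to print or save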

Wrapping up

This package was created to aid in the storage of confusion matrix outputs in a flat, row-wise structure for data tables, data frames and data warehouses, since from experience we tend to monitor these test statistics over time as models are retrained.