Threading and Caret: Burning your CPU to improve model training speed

When I am working on a machine learning project in R, it is crucial to save those precious seconds, minutes and hours in model training. To make the most of your CPU and crank up the performance, you will need to load the parallel and doParallel libraries alongside library(caret).

A wrapper function to help with this

I have created a wrapper function to help with this. I store it in a separate file, located in a shared directory, so that I can call it and use it in my R projects.

The bones of the function

The function has no parameters and can be called directly in the script:

amp_up_models <- function(){
  library(parallel)
  library(doParallel)
  # Leave one core available for the operating system
  no_cores <- parallel::detectCores() - 1
  # Start one background R worker process per core
  cluster <- makePSOCKcluster(no_cores)
  # Register the cluster as the parallel backend that caret will pick up
  registerDoParallel(cluster)
  cat("Model amped and ready to go with:", no_cores, "cores. \n")
}

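This function uses the parallel library to detect the number of cores in your system and then uses all but one of them (one is left for the operating system) to run more than one task at the same time. Strictly speaking, a PSOCK cluster runs separate R worker processes rather than threads. With caret, the tasks that get shared out are the model fits for each resample and each tuning parameter combination; with random forests, for example, that means each candidate value of mtry on each resample.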

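The makePSOCKcluster call then starts that many background R worker processes, which talk to your main session over sockets, and registerDoParallel registers the cluster as the parallel backend for foreach, which is what caret uses under the hood to farm out work.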
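To see what registering a backend actually does, here is a minimal sketch that runs a trivial foreach loop on a small cluster and collects the process ID of whichever worker handled each task. The cluster size of 2 and the variable names are my own choices for illustration, not part of the wrapper function.

```r
library(doParallel)  # also attaches foreach and parallel

# Two workers is an arbitrary choice for this illustration
cl <- parallel::makePSOCKcluster(2)
registerDoParallel(cl)

# Each task reports the process ID of the worker that ran it
pids <- foreach(i = 1:4, .combine = c) %dopar% Sys.getpid()

# Tidy up: stop the workers and switch foreach back to sequential mode
parallel::stopCluster(cl)
foreach::registerDoSEQ()

# The IDs belong to worker processes, not to the main R session
print(unique(pids))
print(Sys.getpid())
```

The worker IDs differ from the ID of the main session, which shows the work ran in separate R processes.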

Finally, the cat statement prints a friendly message to the user indicating that the model is amped and ready to go with: x cores.
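One thing to be aware of is that the function does not hand the cluster object back, so there is no handle for shutting the workers down once training has finished. A minimal sketch of a variant that does return it is below. The names amp_up_models_with_handle() and amp_down_models() are hypothetical and are not in my original script.

```r
library(parallel)
library(doParallel)

# Hypothetical variant: same set-up, but returns the cluster object
amp_up_models_with_handle <- function() {
  # detectCores() can return NA, so fall back to a single worker
  detected <- parallel::detectCores()
  no_cores <- if (is.na(detected)) 1L else max(1L, detected - 1L)
  cluster <- parallel::makePSOCKcluster(no_cores)
  doParallel::registerDoParallel(cluster)
  cat("Model amped and ready to go with:", no_cores, "cores.\n")
  invisible(cluster)
}

# Hypothetical companion: stop the workers and go back to sequential mode
amp_down_models <- function(cluster) {
  parallel::stopCluster(cluster)
  foreach::registerDoSEQ()
  invisible(NULL)
}

cl <- amp_up_models_with_handle()
workers <- foreach::getDoParWorkers()  # number of workers foreach can see
amp_down_models(cl)
```

Calling amp_down_models(cl) at the end of a script frees the worker processes rather than leaving them idle until the R session closes.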

Using the function

Using the function is a single call to set things in motion:

amp_up_models()

This then prints the message in the cat statement.

Storing the function separately

I will now save this function on its own as an R script, with the file name Amp up models.R, to my working directory.

Using the function with a new project

To import the function back into my environment I will use the following code:

source("C:/Users/GaryHutson/Desktop/Caret and Threading/Amp up models.R")

The source function runs an external R script inside your current session, so you can reuse its functions in your current project. Sourcing the file creates the function in your environment automatically.
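If the shared directory moves, source() fails with a fairly terse error, so it can help to check for the file first. Here is a minimal sketch, using a throwaway script written to a temporary directory as a stand-in for the real file. The path and the placeholder function body are assumptions for illustration only.

```r
# Hypothetical stand-in for the shared script, written to a temp directory
script_path <- file.path(tempdir(), "Amp up models.R")
writeLines("amp_up_models <- function() cat('placeholder\\n')", script_path)

# Check the file is where we expect before sourcing it
if (file.exists(script_path)) {
  source(script_path)
} else {
  stop("Could not find the script at: ", script_path)
}

# Confirm the function now exists in the environment
function_loaded <- exists("amp_up_models")
print(function_loaded)
```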

Using our new function with CARET

CARET is an extensive machine learning library, developed by Max Kuhn, who has also worked on TidyModels, Parsnip and Recipes.

This is still my go-to for ML algorithms, as there are many more to choose from. I know that Max and his team are working hard to add these algorithms to Parsnip, which supports the tidy way of doing things. See: https://www.tidyverse.org/blog/2018/11/parsnip-0-0-1/ for a cool blog on what Parsnip has to offer.

For a reference, check out the list of available algorithms in caret: https://rdrr.io/cran/caret/man/models.html.

Anyway, to use our new function with CARET we need the option allowParallel = TRUE. This lives in CARET's trainControl() options, which are handed to train() through the trControl argument. It is already the default, but setting it explicitly makes the intent clear. See the below example of how to set this up in CARET:

caret_example <- train(Y ~ X1 + X2,
                       data = my_data,
                       method = "rf",
                       trControl = trainControl(allowParallel = TRUE))

Here we create a train object native to caret and this consists of:

  • Y – the dependent variable I want to predict
  • ~ meaning predict Y given anything to the right
  • X1 + X2 are my predictor variables / independent variables
  • Method equal to rf which is code for random forest
  • The key parameter is the allowParallel switch, set equal to TRUE inside trainControl() and passed in via trControl

Once this has been activated, you will notice that your computer jumps to life and will use up many more cores. Perhaps step away from the PC for a while, as this will cause other programs to slow down.
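To tie the pieces together, here is a minimal end-to-end sketch on the built-in iris data. It assumes the caret, doParallel and randomForest packages are installed. The two-worker cluster, the 5-fold cross-validation and the mtry grid are arbitrary choices for illustration. The allowParallel switch is passed through trainControl(), which is where caret documents it.

```r
library(caret)
library(doParallel)

# Start and register a small cluster (two workers is an arbitrary choice)
cl <- parallel::makePSOCKcluster(2)
registerDoParallel(cl)

set.seed(123)
# 5-fold cross-validation, with parallel processing explicitly allowed
ctrl <- trainControl(method = "cv", number = 5, allowParallel = TRUE)

# Each fold and mtry combination can be fitted on a different worker
fit <- train(Species ~ .,
             data = iris,
             method = "rf",
             tuneGrid = data.frame(mtry = c(2, 3)),
             trControl = ctrl)

# Shut the workers down once training has finished
parallel::stopCluster(cl)
foreach::registerDoSEQ()

print(fit$bestTune)
```

On a dataset this small the overhead of starting workers can outweigh the gain, so the benefit shows up on larger data and bigger tuning grids.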

Mind you, with modern computers, we don’t need to feed the hamster too often.

Show me the source code

The source code is available from my GitHub page: https://github.com/StatsGary/Using-Threading-with-CARET/blob/master/Amp%20up%20models.R.

Thanks and keep an eye out for my other posts.

