Bootstrap and Cross-Validation Using Tidymodels

Mon, Mar 28, 2022 2-minute read

This blog post follows up on Julia Silge’s penguin blog post (link). Her post uses only bootstrap resampling. Here I use both bootstrap resampling and 10-fold cross-validation, to practice with the tidymodels meta-package.

library(tidyverse)
library(tidymodels)
library(palmerpenguins)
penguins <- penguins %>%
  filter(!is.na(sex))

Split the data into training and testing sets, stratified by sex:

set.seed(2022)

split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)
penguins_test <- testing(split)
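Because the split is stratified by sex, the training and testing sets should contain nearly the same share of females and males. Below is a quick check, written as a sketch that repeats the post's setup so it runs on its own. The names sex_props and set are illustrative and not from the original analysis.

```r
library(tidymodels)
library(palmerpenguins)

# Same filtering and seed as the post
penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)

# Share of each sex within the training and testing sets
sex_props <- bind_rows(
  train = training(split),
  test = testing(split),
  .id = "set"
) %>%
  count(set, sex) %>%
  group_by(set) %>%
  mutate(prop = n / sum(n)) %>%
  ungroup()

sex_props
```

Both sets should come out close to 50/50. Without strata = sex, the proportions could drift further apart by chance, especially in the smaller test set.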

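Penguin recipe and model. The recipe predicts sex from all other columns, but gives island and year an "id" role, so they stay in the data without being used as predictors. The model is a logistic regression fitted with glm: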

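pen_rec <- recipe(sex ~ ., data = penguins_train) %>%
  # keep island and year in the data, but not as predictors
  update_role(island, year, new_role = "id")

# Optional: estimate the recipe on the training data to inspect it
# pen_prep <- pen_rec %>%
#   prep()

# Logistic regression using base R's glm() as the engine
pen_mod <- logistic_reg() %>%
  set_engine("glm")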
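To confirm what update_role() did, the recipe's summary lists each variable with its role, and prep() followed by bake(new_data = NULL) returns the processed training data (which is what the commented-out prep() step would give access to). This is a minimal sketch that rebuilds the recipe so it runs standalone; rec_roles and pen_baked are illustrative names.

```r
library(tidymodels)
library(palmerpenguins)

penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)

pen_rec <- recipe(sex ~ ., data = penguins_train) %>%
  update_role(island, year, new_role = "id")

# One row per variable, with its type, role, and source
rec_roles <- summary(pen_rec)
rec_roles %>% select(variable, role)

# prep() estimates the recipe; bake(new_data = NULL) returns the processed
# training set. The recipe has no steps, so all 8 columns pass through,
# including the "id" columns.
pen_baked <- pen_rec %>%
  prep() %>%
  bake(new_data = NULL)
```

The summary should show sex as the outcome, island and year as "id", and the five remaining columns as predictors.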

Bootstrap (200 resamples of the training set):

pen_boots <- bootstraps(penguins_train, times = 200)
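Each bootstrap resample draws as many rows as the training set has, with replacement. The rows that are never drawn (the "out-of-bag" rows) form the assessment set used for evaluation. On average a fraction of about (1 - 1/n)^n, roughly 0.368, of the rows end up out-of-bag. A minimal sketch that checks this on the 200 resamples (boot_sizes and oob_frac are illustrative names):

```r
library(tidymodels)
library(palmerpenguins)

penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)

pen_boots <- bootstraps(penguins_train, times = 200)

# For each resample: rows used for fitting (analysis, drawn with replacement)
# and out-of-bag rows used for evaluation (assessment)
boot_sizes <- tibble(
  id = pen_boots$id,
  n_analysis = map_int(pen_boots$splits, ~ nrow(analysis(.x))),
  n_assessment = map_int(pen_boots$splits, ~ nrow(assessment(.x)))
) %>%
  mutate(oob_frac = n_assessment / nrow(penguins_train))

# Observed average out-of-bag fraction vs. the theoretical value
mean(boot_sizes$oob_frac)
(1 - 1 / nrow(penguins_train))^nrow(penguins_train)
```

Because each model only sees about 63% of the distinct training rows, bootstrap performance estimates can lean slightly pessimistic.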

10-fold cross-validation:

pen_folds <- vfold_cv(penguins_train, v = 10)
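Unlike the bootstrap, 10-fold cross-validation partitions the training rows: each row is held out in exactly one fold, and each model is fitted on the other nine tenths. A minimal sketch that checks this property, using rsample's complement() to get the held-out row numbers of each split (fold_sizes and held_out are illustrative names):

```r
library(tidymodels)
library(palmerpenguins)

penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)

pen_folds <- vfold_cv(penguins_train, v = 10)

# Rows used for fitting and for evaluation in each fold
fold_sizes <- tibble(
  id = pen_folds$id,
  n_analysis = map_int(pen_folds$splits, ~ nrow(analysis(.x))),
  n_assessment = map_int(pen_folds$splits, ~ nrow(assessment(.x)))
)

# Held-out row numbers, pooled across all 10 folds
held_out <- unlist(map(pen_folds$splits, complement))

fold_sizes
```

Every training row should appear exactly once in held_out, which is why the 10 assessment sets together cover the whole training set.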

Fit the workflow to each bootstrap resample, saving the predictions for the ROC curves:

boot_wf <- workflow() %>%
  add_recipe(pen_rec) %>%
  add_model(pen_mod) %>%
  fit_resamples(
    resamples = pen_boots,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE)
  )
boot_wf %>%
  collect_metrics()
## # A tibble: 2 x 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.899   200 0.00182 Preprocessor1_Model1
## 2 roc_auc  binary     0.962   200 0.00105 Preprocessor1_Model1

One ROC curve per bootstrap resample:

boot_wf %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(sex, .pred_female) %>%
  ungroup() %>%
  ggplot(aes(1 - specificity, sensitivity, color = id)) +
  geom_path(show.legend = FALSE, alpha = 0.5, size = 1)

Fit the same workflow to the 10 cross-validation folds:

cv_wf <- workflow() %>%
  add_recipe(pen_rec) %>%
  add_model(pen_mod) %>%
  fit_resamples(
    resamples = pen_folds,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE)
  )
cv_wf %>%
  collect_metrics()
## # A tibble: 2 x 6
##   .metric  .estimator  mean     n std_err .config             
##   <chr>    <chr>      <dbl> <int>   <dbl> <chr>               
## 1 accuracy binary     0.908    10  0.0198 Preprocessor1_Model1
## 2 roc_auc  binary     0.964    10  0.0137 Preprocessor1_Model1
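The mean and std_err columns from collect_metrics() are summaries over the individual resamples: the average of the per-resample estimates, and their standard deviation divided by the square root of the number of resamples. A sketch that recomputes them by hand from collect_metrics(summarize = FALSE). It rebuilds the setup to run standalone, and because the folds here are drawn right after the split (not after the bootstrap draws), the numbers will not match the table above exactly.

```r
library(tidymodels)
library(palmerpenguins)

# Rebuild the post's setup (data, split, recipe, model)
penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)

pen_rec <- recipe(sex ~ ., data = penguins_train) %>%
  update_role(island, year, new_role = "id")

pen_mod <- logistic_reg() %>%
  set_engine("glm")

# Folds drawn directly after the split, so they differ from the post's folds
pen_folds <- vfold_cv(penguins_train, v = 10)

cv_res <- workflow() %>%
  add_recipe(pen_rec) %>%
  add_model(pen_mod) %>%
  fit_resamples(resamples = pen_folds)

# One row per fold and metric
per_fold <- collect_metrics(cv_res, summarize = FALSE)

# Mean and standard error (sd / sqrt(n)) across folds, computed by hand
manual <- per_fold %>%
  group_by(.metric) %>%
  summarise(
    mean = mean(.estimate),
    n = n(),
    std_err = sd(.estimate) / sqrt(n())
  )

# Hand-computed summary next to tune's own summary
comparison <- inner_join(
  manual,
  collect_metrics(cv_res),
  by = ".metric",
  suffix = c("_manual", "_tune")
)

comparison %>%
  select(.metric, mean_manual, mean_tune, std_err_manual, std_err_tune)
```

This also explains why the bootstrap std_err is much smaller than the cross-validation one: it divides by the square root of 200 resamples rather than 10.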

One ROC curve per fold:

cv_wf %>%
  collect_predictions() %>%
  group_by(id) %>%
  roc_curve(sex, .pred_female) %>%
  ungroup() %>%
  ggplot(aes(1 - specificity, sensitivity, color = id)) +
  geom_path(show.legend = FALSE, alpha = 0.6, size = 1) +
  coord_equal()

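Cross-validation gives slightly higher estimates than the bootstrap (accuracy 0.908 vs. 0.899, ROC AUC 0.964 vs. 0.962), but both gaps are well within the cross-validation standard errors (0.0198 and 0.0137), so the two resampling methods essentially agree. Both are estimating the performance of the same model; neither makes the model itself better.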
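To put the gap in perspective, the arithmetic below expresses the difference between the two means in units of the cross-validation standard error, using the numbers from the two collect_metrics() tables. This is a rough yardstick, not a formal test: both estimates come from the same training data, so they are not independent. The table est and its column names are illustrative.

```r
library(tibble)
library(dplyr)

# Estimates copied from the two collect_metrics() tables above
est <- tribble(
  ~.metric,   ~boot_mean, ~boot_se, ~cv_mean, ~cv_se,
  "accuracy", 0.899,      0.00182,  0.908,    0.0198,
  "roc_auc",  0.962,      0.00105,  0.964,    0.0137
)

# Difference between the CV and bootstrap means, in units of the CV standard error
comparison <- est %>%
  mutate(
    diff = cv_mean - boot_mean,
    diff_in_cv_se = diff / cv_se
  )

comparison
```

Both differences come out well under one standard error (about 0.45 for accuracy and 0.15 for ROC AUC).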

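Without resampling, fit the model once on the full training set and evaluate it on that same data. These resubstitution estimates tend to be optimistic, because the model is scored on the rows it was trained on: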

penguin_fit <- workflow() %>%
  add_model(pen_mod) %>%
  add_recipe(pen_rec) %>%
  fit(penguins_train) 

# Combine three class metrics into a single metric function
class_metrics <- metric_set(accuracy, sensitivity, specificity)

penguin_fit %>%
  predict(penguins_train) %>%
  bind_cols(penguins_train) %>%
  class_metrics(truth = sex, estimate = .pred_class)
## # A tibble: 3 x 3
##   .metric     .estimator .estimate
##   <chr>       <chr>          <dbl>
## 1 accuracy    binary         0.924
## 2 sensitivity binary         0.919
## 3 specificity binary         0.929
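Sensitivity and specificity both come from the confusion matrix. Since "female" is the first level of sex, yardstick treats it as the event: sensitivity is the share of true females predicted female, and specificity is the share of true males predicted male. A sketch that rebuilds the fitted workflow and checks this by hand (train_preds, tab, sens_manual, and spec_manual are illustrative names):

```r
library(tidymodels)
library(palmerpenguins)

penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)

pen_rec <- recipe(sex ~ ., data = penguins_train) %>%
  update_role(island, year, new_role = "id")

penguin_fit <- workflow() %>%
  add_model(logistic_reg() %>% set_engine("glm")) %>%
  add_recipe(pen_rec) %>%
  fit(penguins_train)

train_preds <- penguin_fit %>%
  predict(penguins_train) %>%
  bind_cols(penguins_train)

# Cross-tabulate predicted vs. true sex
cm <- conf_mat(train_preds, truth = sex, estimate = .pred_class)
cm

# "female" is the first factor level, so it is the event class
tab <- cm$table  # rows = Prediction, columns = Truth
sens_manual <- tab["female", "female"] / sum(tab[, "female"])
spec_manual <- tab["male", "male"] / sum(tab[, "male"])
```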
penguin_fit %>%
  predict(penguins_train, type = "prob") %>%
  bind_cols(penguins_train) %>%
  roc_curve(sex, .pred_female) %>%
  autoplot() 
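The area under this ROC curve has a direct interpretation: it is the probability that a randomly chosen female gets a higher .pred_female than a randomly chosen male (ties count as one half). A sketch that computes it from ranks (the Mann-Whitney form) and compares it with yardstick's roc_auc(); pos, r, and auc_manual are illustrative names, and the setup is repeated so the block runs standalone.

```r
library(tidymodels)
library(palmerpenguins)

penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)
penguins_train <- training(split)

pen_rec <- recipe(sex ~ ., data = penguins_train) %>%
  update_role(island, year, new_role = "id")

penguin_fit <- workflow() %>%
  add_model(logistic_reg() %>% set_engine("glm")) %>%
  add_recipe(pen_rec) %>%
  fit(penguins_train)

train_probs <- penguin_fit %>%
  predict(penguins_train, type = "prob") %>%
  bind_cols(penguins_train)

# yardstick's AUC for the training-set ROC curve
train_auc <- roc_auc(train_probs, truth = sex, .pred_female)

# Rank-based AUC: average ranks handle ties as one half
pos <- train_probs$sex == "female"
r <- rank(train_probs$.pred_female)
n_pos <- sum(pos)
n_neg <- sum(!pos)
auc_manual <- (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)

c(yardstick = train_auc$.estimate, manual = auc_manual)
```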

Finally, evaluate the model on the held-out test set:

penguin_fit %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  accuracy(sex, .pred_class)
## # A tibble: 1 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.929
penguin_fit %>%
  predict(penguins_test, type = "prob") %>%
  bind_cols(penguins_test) %>%
  roc_curve(sex, .pred_female) %>%
  autoplot()
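tidymodels also offers last_fit(), which fits a workflow on the training portion of the initial split and evaluates it on the testing portion in one call. The sketch below assumes the same seed and split as the post; since glm() fitting is deterministic, its test accuracy should match the manual fit-then-predict approach above. pen_wf and final_res are illustrative names.

```r
library(tidymodels)
library(palmerpenguins)

penguins <- penguins %>%
  filter(!is.na(sex))

set.seed(2022)
split <- initial_split(penguins, strata = sex)

pen_rec <- recipe(sex ~ ., data = training(split)) %>%
  update_role(island, year, new_role = "id")

pen_wf <- workflow() %>%
  add_recipe(pen_rec) %>%
  add_model(logistic_reg() %>% set_engine("glm"))

# Fit on training(split), then predict and score on testing(split)
final_res <- last_fit(pen_wf, split)

collect_metrics(final_res)

final_res %>%
  collect_predictions() %>%
  roc_curve(sex, .pred_female) %>%
  autoplot()

# The same test-set accuracy computed the manual way, as in the post
manual_acc <- pen_wf %>%
  fit(training(split)) %>%
  predict(testing(split)) %>%
  bind_cols(testing(split)) %>%
  accuracy(sex, .pred_class)
```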