library(DescTools) # for some combinatorics functions.
library(doParallel) #make the training run much faster
library(tictoc) #timing things
library(pdp) #partial dependence plots
# Load the example data and record the outcome column used throughout the script.
data(ad_data) # Alzheimer's disease data
dataset <- ad_data # Replace with read_csv() or fread()
outcome_var <- "Class"
tbl <- table(dataset[[outcome_var]]) #class to predict is "Class"
# The first name in the outcome table is taken as the positive class; it is
# used later to build the `.pred_<class>` column name for the ROC curve.
positive_class <- names(tbl)[1]
dataset %>% sapply(class) %>% table # types of variables
factor integer numeric
2 2 127
# Train/test split of the data, stratified on the outcome column.
# NOTE(review): the initial_split() call is syntactically incomplete as shown
# (the end of the `dataset %>%` pipeline is not visible) - confirm against the original.
trn_tst_split <-
initial_split(dataset %>%
strata = outcome_var) # Stratify on outcome
dataset_train <- training(trn_tst_split)
dataset_test <- testing(trn_tst_split)
# Resamples used when tuning the models (passed as `resamples = folds` below).
# NOTE(review): the function call that creates the folds is not visible here.
folds <-
strata = outcome_var) # Stratify on outcome
# WT <- map(dataset_train %>% select(where(is.numeric)),
# ~wilcox.test(.x[which(dataset_train$Class == "Impaired")],
# .x[which(dataset_train$Class == "Control")])$p.value) %>% unlist %>% sort()
# Logistic regression
lr_spec <-
logistic_reg() %>%
# Random forest
# mtry and min_n are left as tune() placeholders; the number of trees is fixed at 1000.
rf_spec <-
rand_forest(mtry = tune(),
min_n = tune(),
trees = 1000) %>%
set_engine("ranger") %>%
## Simple formulas
# formulas <-
# list("Class ~ tau + Ab_42",
# "Class ~ 1",
# "Class ~ .") %>%
# lapply(as.formula)
# 1 level of variable selection using custom function
# function to generate all possible combinations of formulas with n variables added to a base formula
# models with gender + age + 1 additional variable
# NOTE(review): add_n_var_formulas() is defined elsewhere; the contents of
# `x_vars` and the remaining arguments are not visible here.
formulas <- add_n_var_formulas(
y_var = outcome_var,
x_vars = c(
# start with gender and age
data = dataset_train,
# One preprocessing object per candidate formula, each built on the training data.
# NOTE(review): the body of the mapped function is only partly visible here.
recipes <-
map(formulas, function(form) {
data = dataset_train) %>%
#step_naomit(all_predictors(),skip = TRUE)
# step_impute_median() plays nice with ranger rf models, step_naomit() does not.
# Pair the preprocessors with the model specs. Only the logistic regression
# (`lr`) is enabled; the random forest entry is commented out.
model_workflows <-
preproc = recipes,
models = list(#rf = rf_spec, #uncomment to make rf go as well
lr = lr_spec)
# Tuning control options: keep the out-of-sample predictions and the workflow
# with the results, and parallelise over "everything".
grid_ctrl <-
save_pred = TRUE,
parallel_over = "everything",
save_workflow = TRUE
# Cache: reuse previously trained models from disk when the file exists,
# otherwise train them and save the result to the same path.
model_file <- "./saved_models/test_one_term_models.Rds"
if (file.exists(model_file)){
model_workflows <- readRDS(file = model_file)
} else {
# Train all the models
#- takes 3min for one variable, 129 models on 12 cores.
# One PSOCK worker per logical core.
# NOTE(review): no cluster registration or stopCluster(cl) is visible here -
# confirm the cluster is registered before training and shut down afterwards.
all_cores <- parallel::detectCores(logical = TRUE)
cl <- makePSOCKcluster(all_cores)
tic(msg = "model training time")
# Tune every workflow on the resampling folds with a fixed seed and a grid of 25.
model_workflows <-
model_workflows %>%
seed = 1503,
grid = 25,
control = grid_ctrl,
resamples = folds,
verbose = TRUE)
saveRDS(model_workflows, file = model_file)
# Unsummarised (per-resample) ROC AUC values for every workflow.
roc_values <-
model_workflows %>%
collect_metrics(summarize = FALSE) %>%
filter(.metric == "roc_auc")
# Per workflow: mean AUC plus the 20th and 80th percentiles of the resample
# AUCs (drawn as error bars in the plots below).
summary_stats <-
roc_values %>%
group_by(wflow_id, model) %>%
ROC = mean(.estimate),
lower = quantile(.estimate,probs = 0.2),
upper = quantile(.estimate,probs = 0.8),
.groups = "drop"
) %>%
# Relabel the two reference workflows and order workflow ids by mean AUC so
# the plot's y axis is sorted.
model = if_else(grepl("All Predictors",wflow_id),"All Predictors",model),
model = if_else(grepl("Base Model",wflow_id),"Base Model",model),
wflow_id = factor(wflow_id),
wflow_id = reorder(wflow_id, ROC)
# Plot the summary
# One point per workflow at its mean resample AUC; error bars span the
# lower/upper percentiles computed in summary_stats.
summary_stats %>%
ggplot(aes(x = ROC, y = wflow_id, color = model)) +
labs(color = "Model type",
title = "Comparing models of the form `Class ~ male + age +...`",
subtitle = "To separate Alzheimer's patients from healthy controls.",
x = "Area under ROC curve",
y = "Additional Variable")+
geom_point() +
geom_vline(xintercept = .5)+#line of predicting no better than chance
geom_errorbar(aes(xmin = lower, xmax = upper), width = .25) +
# Shortlist: the two reference models plus every workflow whose mean AUC is at
# or above the 80th percentile of all mean AUCs.
model_subset <-
summary_stats %>%
filter(model == "All Predictors" |
model == "Base Model" |
ROC>=quantile(ROC, 0.8) )
# Same plot as for the full summary, restricted to the shortlist.
model_subset %>%
ggplot(aes(x = ROC, y = wflow_id, color = model)) +
labs(color = "Model type",
title = "Comparing models of the form `Class ~ male + age +...`",
subtitle = "To separate Alzheimer's patients from healthy controls.",
x = "Area under ROC curve",
y = "Additional Variable")+
geom_point() +
geom_errorbar(aes(xmin = lower, xmax = upper), width = .25) +
geom_vline(xintercept = .5)+
It looks like a logistic regression of the form `Class ~ male + age + tau`
performs better than the kitchen-sink model with all predictors.
# NOTE(review): the start of this function definition (its name and the first
# parameter, apparently `workflowset`) is not visible here.
# Parameters shown: `wfl_id` workflow id, `mtr` metric name, `splt` train/test split.
wfl_id = "tau_lr",
mtr = "roc_auc",
splt = trn_tst_split){
# Best tuning parameters for this workflow according to `mtr`.
best_results <-
workflowset %>%
extract_workflow_set_result(wfl_id) %>%
select_best(metric = mtr)
# Finalise the workflow with those parameters, run last_fit() on the supplied
# split, and tag the result with its workflow id.
best_test_results <-
workflowset %>%
extract_workflow(wfl_id) %>%
finalize_workflow(best_results) %>%
last_fit(split = splt) %>%
mutate("wflow_id" = wfl_id)
# Prepend "last_fit" to the class vector - presumably because mutate() drops
# it and later extractors dispatch on it; TODO confirm.
class(best_test_results) <- c("last_fit", class(best_test_results))
# Final fit on the train/test split for the winning workflow ("tau_lr").
# NOTE(review): the name of the function being called is not visible here.
best_test_results <-
wfl_id = "tau_lr",
mtr = "roc_auc",
splt = trn_tst_split)
winning_glm <- best_test_results %>% extract_fit_engine() # gets resulting glm
stats::glm(formula = ..y ~ ., family = stats::binomial, data = data)
Deviance Residuals:
Min 1Q Median 3Q Max
-2.4850 -0.7974 0.4556 0.7312 1.6590
Estimate Std. Error z value Pr(>|z|)
(Intercept) 168.2070 123.1545 1.366 0.172
male -0.5209 0.3247 -1.604 0.109
age -158.0884 125.4406 -1.260 0.208
tau -1.8760 0.3426 -5.475 4.37e-08 ***
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 291.98 on 248 degrees of freedom
Residual deviance: 239.80 on 245 degrees of freedom
AIC: 247.8
Number of Fisher Scoring iterations: 4
In this case age and gender no longer appear to be statistically significant, but they may still need to be included for other reasons, e.g. the researcher wants them controlled for, or previous research indicates that they have an effect.
# Partial dependence of the fitted glm on `tau` and a second predictor (not
# visible here), computed against the training data and drawn as a contour plot.
# chull = TRUE adds the outline of the convex hull of the training data.
winning_glm %>%
pdp::partial(pred.var = c("tau",
train = dataset_train) %>%
pdp::plotPartial(contour = TRUE,
col.regions = colorRampPalette(c("red", "white", "blue")),
chull = TRUE,
train = dataset_train)
The black line is the convex hull of the training data, i.e. don’t trust the model outside the black line, because there you are extrapolating.
# ROC curve from the stored predictions of the final fit. The probability
# column is `.pred_<positive_class>` and the truth column is the outcome variable.
best_test_results$.predictions[[1]] %>%
{roc_curve(data = .,!!rlang::data_sym(paste0(".pred_",positive_class)),
truth = !!rlang::data_sym(outcome_var))} %>%
# Build a long data frame of ROC curves for a set of workflows.
# Arguments:
#   model_workflows: workflow set passed on as `workflowset`.
#   split: train/test split passed on as `splt`.
#   win_models: character vector of workflow ids to evaluate.
#   mtr: metric name used to select parameters and to read the final score.
# NOTE(review): several lines of this definition (the mapped call, the
# assignments to `a` and `b`, and some closing brackets) are not visible here.
make_test_set_roc_df <- function(model_workflows,
split = trn_tst_split,
win_models = c("All Predictors_rf",
"All Predictors_lr",
"Base Model_rf",
"Base Model_lr",
mtr = "roc_auc") {
# One final-fit result per requested workflow id.
winner_results <-
win_models %>%
workflowset = model_workflows,
wfl_id = .x,
mtr = mtr,
splt = split))
# ROC curve points per workflow, computed from the unnested predictions.
winner_results %>%
dplyr::select(.predictions,wflow_id) %>%
unnest(cols = .predictions) %>%
group_by(wflow_id) %>%
roc_curve(data = .,
truth = !!rlang::data_sym(outcome_var)) %>%
# Value of metric `mtr` per workflow, pulled out of the .metrics list-column.
winner_results %>%
mutate(roc_a = map_dbl(.metrics,~.x$.estimate[.x$.metric==mtr])) %>%
dplyr::select(roc_a, wflow_id)
# `a` and `b` are presumably the two results above - TODO confirm.
roc_curves<-inner_join(a,b,by = "wflow_id")
# Sort by descending score and freeze that order in the factor levels.
# NOTE(review): `T` can be reassigned; `TRUE` would be safer for `ordered`.
roc_curves %>%
arrange(desc(roc_a),sensitivity,.threshold) %>%
mutate(wflow_id = as.factor(wflow_id),
wflow_id = forcats::fct_inorder(f = wflow_id,ordered = T))
# ROC curves for the shortlisted workflows in model_subset, evaluated on the
# train/test split.
test_roc_curves <-
model_workflows %>%
make_test_set_roc_df(split = trn_tst_split,
win_models = model_subset$wflow_id %>% as.character(),
mtr = "roc_auc")
# Workflow id and score columns only; used below for the labelled points.
# NOTE(review): the end of this pipeline is not visible here.
test_roc <-
test_roc_curves %>%
dplyr::select(wflow_id, roc_a) %>%
# Plot the ROC curves, one colour per workflow, with text labels showing the
# decision threshold rounded to one decimal place.
# NOTE(review): the assignment to `p` and parts of the geom_text() call are
# not visible here.
test_roc_curves %>%
ggplot(.,aes(y = sensitivity,
x = 1-specificity,
color = wflow_id,
label = round(.threshold,1)))+
geom_text(data = . %>%
mutate(rnd = round(.threshold,1)) %>%
group_by(wflow_id, rnd) %>%
labs(title = "ROC curves for separating Alzheimer's patients from healthy controls",
subtitle = "Double click on the legend to only see that curve",
color = "Model (double click one...)",
y = "Sensitivity",
x = "1 - Specificity")+
coord_fixed(ratio = 1,expand = FALSE)+
# Diagonal reference line (slope 1, intercept 0).
geom_abline(slope = 1, intercept = 0)+
theme(legend.position = "right")
p#%>% plotly::ggplotly()
#htmlwidgets::saveWidget(p %>% plotly::ggplotly(), "saved_widgets/ROC.html", selfcontained = F, libdir = "lib")
# Resampling AUC summary for the shortlisted workflows (coloured points and
# error bars), overlaid with each workflow's test-set AUC as a labelled point.
summary_stats %>%
dplyr::filter(as.character(wflow_id) %in% model_subset$wflow_id)%>%
ggplot(aes(x = ROC, y = wflow_id, color = model)) +
labs(color = "Model type",
title = "Comparing models of the form `Class ~ age + male +...`",
subtitle = "To separate Alzheimer's patients from healthy controls. \nColour error bars are training set AUCs, Label is test set AUC.",
x = "Area under ROC curve",
y = "Additional Variable")+
geom_point() +
geom_vline(xintercept = .5)+
geom_errorbar(aes(xmin = lower, xmax = upper), width = .25) +
# color = NULL removes the inherited colour mapping for the test-set layers.
geom_point(data = test_roc, mapping = aes(x = roc_a, y = wflow_id, color = NULL), size = 3) +
geom_label(data = test_roc, mapping = aes(x = roc_a, y = wflow_id, label = round(roc_a,2), color = NULL), size = 2.5)+
#ggsave("saved_images/AUCs.pdf",width = 7,height = 3.5)
So it looks like tau
wins, even in the test set.