Dataset Preparation
Import Data
suppressMessages(library(tidyverse))
# Load the raw mushroom records (path resolved against the working directory).
data <- read.csv(file = "mushrooms.csv")
Convert categorical variables to factors
# Treat every column -- the target and all predictors -- as categorical.
# `data[] <-` keeps the data.frame structure while replacing all columns at
# once, and (unlike a for loop) leaves no stray loop variable behind.
data[] <- lapply(data, as.factor)
Specify order of ordinal variables
# Give the two ordinal predictors their natural ordering. Any value not listed
# in `levels` would become NA.
data$gill_spacing <- factor(
  data$gill_spacing,
  levels = c("crowded", "close", "distant"),
  ordered = TRUE
)
data$num_rings <- factor(
  data$num_rings,
  levels = c(0, 1, 2),
  ordered = TRUE
)
Train-test Split
set.seed(0)
# Split the dataset into an 80% training set and a 20% test set.
# The training size is rounded down and the split is made with a logical mask,
# so every observation lands in exactly one set. With a fractional size, the
# original `1:k` / `(k+1):n` ranges truncated differently and silently dropped
# the last shuffled row.
train_fraction <- 0.8
num_observations <- nrow(data)
train_index <- sample(seq_len(num_observations), num_observations)
num_train <- floor(num_observations * train_fraction)
in_train <- seq_len(num_observations) <= num_train
train_set <- data[train_index[in_train], ]
test_set <- data[train_index[!in_train], ]

# Split training set into the predictor data frame (X) and the target vector (y)
X_train <- train_set
X_train$edibility <- NULL
y_train <- train_set$edibility

# Split test set into the predictor data frame (X) and the target vector (y)
X_test <- test_set
X_test$edibility <- NULL
y_test <- test_set$edibility
Exploratory Data Analysis
Target Variable Distribution
# Share of each edibility class in the full dataset.
as_tibble(data) %>%
  count(edibility) %>%
  mutate(proportion = n / sum(n)) %>%
  select(edibility, proportion)
## # A tibble: 2 × 2
## edibility proportion
## <fct> <dbl>
## 1 edible 0.5
## 2 poisonous 0.5
Predictor-Target Distributions
suppressMessages(library(gridExtra))
# Plot the share of each edibility class within every category of a predictor.
#
# predictor: bare (unquoted) column name of the predictor to plot.
# df:        data frame holding `predictor` and `edibility`; defaults to the
#            global `data`, so existing one-argument calls keep working.
# Returns a ggplot object: 100%-stacked bars, flipped horizontal, no legend.
plot_bar <- function(predictor, df = data) {
  # `{{ }}` forwards the bare column name (with its environment) into aes(),
  # the documented tidy-eval idiom for this.
  ggplot(df, aes(x = {{ predictor }}, fill = edibility)) +
    geom_bar(position = "fill") +
    coord_flip() +
    ylab("") +
    theme(legend.position = "none")
}
# One proportion plot per predictor of interest, laid out in two columns.
p2 <- plot_bar(bruises)
p3 <- plot_bar(gill_size)
p4 <- plot_bar(stalk_shape)
p5 <- plot_bar(cap_surface)
p6 <- plot_bar(gill_attach)
p7 <- plot_bar(veil_color)
p8 <- plot_bar(habitat)
p9 <- plot_bar(gill_spacing)
p10 <- plot_bar(num_rings)
grid.arrange(
  grobs = list(p2, p3, p4, p5, p6, p7, p8, p9, p10),
  ncol = 2
)

Hyperparameter Tuning
# To tune the models' hyperparameter values, I used grid search,
# 5-fold cross-validation, and AUC.
suppressMessages(library(caret))
# Shared resampling scheme: 5-fold cross-validation scored with ROC AUC.
# AUC needs class probabilities and caret's two-class summary function.
train_control <- trainControl(
  method = "cv",
  number = 5,
  classProbs = TRUE,
  summaryFunction = twoClassSummary
)
Classification Models
Logistic Regression
set.seed(0)
# Elastic-net logistic regression tuned over a 20 x 20 grid:
#   alpha  - mix between the ridge (alpha = 0) and lasso (alpha = 1) penalties
#   lambda - overall amount of regularization
# `length.out` is spelled out; `length =` only worked via partial matching.
hyperparam_grid <- expand.grid(
  alpha = seq(0, 1, length.out = 20),
  lambda = seq(0, 0.5, length.out = 20)
)
logistic_model <- train(
  edibility ~ .,
  data = train_set,
  method = "glmnet",
  metric = "ROC",
  tuneGrid = hyperparam_grid,
  trControl = train_control
)

# Get the best values of alpha and lambda from caret's documented `bestTune`.
best_alpha <- logistic_model$bestTune$alpha
best_lambda <- logistic_model$bestTune$lambda
cat("The best \u03b1 value is", round(best_alpha, 2),
    "and the best \u03bb value is", round(best_lambda, 2))
## The best α value is 0.05 and the best λ value is 0
# Heatmap of the mean cross-validated AUC over the (alpha, lambda) grid.
logistic_grid <- logistic_model$results
logistic_grid %>%
  ggplot(aes(x = alpha, y = lambda, fill = ROC)) +
  geom_tile() +
  labs(
    x = "\u03b1",
    y = "\u03bb",
    fill = "Average \nValidation \nSet AUC"
  )

Random Forest
set.seed(0)
# mtry is the number of predictors sampled as split candidates at each node.
# caret's x/y interface is used on purpose. The formula interface would
# dummy-encode every factor, so mtry would be tuned over indicator columns,
# while this grid (1..ncol(X_train)) and the later randomForest() refit both
# count original predictors.
num_predictors <- ncol(X_train)
hyperparam_grid <- expand.grid(mtry = seq_len(num_predictors))
random_f_model <- train(
  x = X_train,
  y = y_train,
  method = "rf",
  tuneGrid = hyperparam_grid,
  trControl = train_control,
  metric = "ROC"
)

# Get the best value of mtry from caret's documented `bestTune`.
best_mtry <- random_f_model$bestTune$mtry
cat("The best mtry value is", best_mtry)
## The best mtry value is 9
# Validation curve: mean cross-validated AUC for each mtry, +/- one SD.
random_f_model$results %>%
  ggplot(aes(x = mtry, y = ROC)) +
  geom_errorbar(aes(ymin = ROC - ROCSD, ymax = ROC + ROCSD), width = 0.1) +
  geom_line() +
  geom_point() +
  ylab("Average Validation set AUC (± SD)") +
  scale_x_continuous(breaks = 1:num_predictors)

k-NN
set.seed(0)
# k is the number of neighbours; candidates are 1..k_max.
k_max <- 10
hyperparam_grid <- expand.grid(k = seq_len(k_max))
knn_model <- train(
  edibility ~ .,
  data = train_set,
  method = "knn",
  metric = "ROC",
  tuneGrid = hyperparam_grid,
  trControl = train_control
)

# Get the best value of k from caret's documented `bestTune`.
best_k <- knn_model$bestTune$k
cat("The best k value is", best_k)
## The best k value is 4
# Validation curve: mean cross-validated AUC for each k, +/- one SD.
knn_model$results %>%
  ggplot(aes(x = k, y = ROC)) +
  geom_errorbar(aes(ymin = ROC - ROCSD, ymax = ROC + ROCSD), width = 0.1) +
  geom_line() +
  geom_point() +
  ylab("Average Validation set AUC (± SD)") +
  scale_x_continuous(breaks = 1:k_max)

Model Evaluation
# For each test-set mushroom, every model's predicted probability of the
# "poisonous" class. predict(type = "prob") returns one column per class.
logistic_pred_prob <-
  predict(logistic_model, newdata = X_test, type = "prob")[["poisonous"]]
knn_pred_prob <-
  predict(knn_model, newdata = X_test, type = "prob")[["poisonous"]]
random_f_pred_prob <-
  predict(random_f_model, newdata = X_test, type = "prob")[["poisonous"]]

# Collect the test-set probabilities, one column per model.
pred_probs <- data.frame(
  logistic = logistic_pred_prob,
  random_forest = random_f_pred_prob,
  knn = knn_pred_prob
)
suppressMessages(library(pROC))
# Test-set ROC AUC for one model's predicted probabilities of "poisonous".
#
# pred_prob: numeric vector of P(poisonous), aligned element-wise with
#            `response`.
# response:  true classes; defaults to the global `y_test`, so existing
#            one-argument calls are unchanged.
# Returns the pROC "auc" object (a length-one numeric).
get_AUC <- function(pred_prob, response = y_test) {
  # Fix the level order (controls, cases) and the direction explicitly so
  # pROC does not guess them: a higher probability must mean "poisonous".
  roc_obj <- roc(
    response = response,
    predictor = pred_prob,
    levels = c("edible", "poisonous"),
    direction = "<"
  )
  auc(roc_obj)
}
# Calculate the test set AUC for each classification model. vapply() walks the
# data-frame columns directly (apply() would first coerce the frame to a
# matrix). It also guarantees one plain numeric value per model, named after
# the column.
aucs <- vapply(
  pred_probs,
  function(prob) as.numeric(get_AUC(prob)),
  numeric(1)
)
sorted_aucs <- sort(aucs, decreasing = TRUE)
model_order <- names(sorted_aucs)

# Combine model names and AUCs into a data frame, best model first.
aucs_df <- data.frame(
  model = model_order,
  auc = unname(sorted_aucs)
)
# Plot the test set AUCs for all the classification models, best first.
# The y-axis is zoomed to [0.85, 1] as before. The floor drops below 0.85 when
# a model scores lower, so no bar is cut off entirely by the zoom.
ggplot(aucs_df, aes(x = model, y = auc)) +
  geom_col() +
  coord_cartesian(ylim = c(min(0.85, min(aucs_df$auc) - 0.05), 1)) +
  scale_x_discrete(limits = model_order) +
  geom_text(aes(label = round(auc, 3)), vjust = 2, colour = "white") +
  ylab("Test set AUC") +
  xlab("Classification model")

Model Interpretation
set.seed(0)
# Refit a random forest with the tuned mtry so that importance is reported per
# original predictor.
suppressMessages(library(randomForest))
random_f_model <- randomForest(
  edibility ~ .,
  data = train_set,
  mtry = best_mtry,
  importance = TRUE
)

# MeanDecreaseGini: for each predictor, the total decrease in node impurity
# (Gini index) from splits on that predictor, averaged over all trees.
# The importance matrix is computed once and reused.
importance_matrix <- importance(random_f_model)
importance_random_f <- data.frame(
  predictor = rownames(importance_matrix),
  importance = unname(importance_matrix[, "MeanDecreaseGini"])
)

# Sort ascending so that tail() yields the five most important predictors.
sorted_importances <- importance_random_f[
  order(importance_random_f$importance),
]
largest_importances <- tail(sorted_importances, 5)
# Plot the five largest predictor importances, most important at the top.
# The x-axis upper bound comes from the data, plus 15% headroom for the text
# labels. The former hard-coded xlim(0, 133) censored, i.e. silently removed,
# any bar whose importance exceeded 133.
ggplot(largest_importances, aes(x = importance, y = predictor)) +
  geom_col(width = 0.6) +
  geom_text(aes(label = round(importance, 2)), hjust = -0.2) +
  scale_x_continuous(
    limits = c(0, NA),
    expand = expansion(mult = c(0, 0.15))
  ) +
  scale_y_discrete(limits = largest_importances$predictor) +
  xlab("Importance (Average Decrease in Gini Index)") +
  ylab("Predictor") +
  theme_grey(base_size = 13)
