In this post, I’d like to show you how to use the newly written package on Bayesian Additive Regression Trees i.e. BART Machine for Restaurant Revenue Prediction with R. The datasets are part of the passed Kaggle competition that can be found here.
What is BART?
BART is the Bayesian sibling of Random Forest. For some applications, Bayesian approaches with probabilistic assumptions work better than pure algorithmic ensemble of trees due to underlying prior assumptions. BART utilizes prior assumptions for parameters estimate and splits that occur in classic tree models. There are multiple distinctions between BART and Random Forest. One is the user defined ability to set generalized Bernoulli priors on features to be selected for each split, whereas in Random Forest, this prior is uniform on features. The other ones are having multiple hyperparameters controlling the growth of trees more than Random Forest and the ability to find explicit interactions among features, which can hardly be found in general Machine/Statistical learning algorithms/models.
BART in action
Speaking of competition, TFI has been one of the most unusual Kaggle competitions so far and has confused and hurt many contestants as well as myself because of the following reasons:
- Great disparity among the number of training and testing samples. The training set has 137 samples with 37 obfuscated features out of 42 with the
revenue
- as the response variable. The test set has 100,000 samples (poisoned data and robust model indications!).
- It was very hard to get consistent cross-validation scores even with repeated CV.
- Almost all of the features are weak and some are incompatible in train and test sets.
Many
- of the top seed Kaggle Masters did poorly in this competition which adds to its strangeness and perhaps if you knew less, you could have been more successful in this competition!
So, let’s briefly look at the data first;
data <- read.csv('train.csv', sep=',', stringsAsFactor=F) test <- read.csv('test.csv', sep=',', stringsAsFactor=F) revenue <- data$revenue train <- data[,-ncol(data)] all <- rbind(train, test)
As you go through exploring all dataset, you will figure out many odd cases, like the feature City, the number of unique cities used in training set length(unique(data$City))
is 34 whereas the number of unique cities used in test set is 57. You can check the forum to see more odd cases. Of course, these cases could happen in real data science and everyone should now how to handle them. This post, however, does not deal with that kind of behaviors.
By plotting the histogram of revenue, one can see its skewness easily, so the natural log(revenue)
better demonstrates its (nearly) log normal distribution and perhaps suggests existence of outliers.
library(ggplot2) qplot(log(data$revenue), geom="histogram", binwidth=0.1, main="Histogram of Revenue", xlab="revenue", fill=I("red"), col=I("black"))
To plot the correlations among (numeric) features and the response, we can use the corrplot
package
library(corrplot) correlations <- cor(data[,6:43]) corrplot(correlations, method="ellipse", order="hclust")
Before going further, I’d like to emphasize that when the data is small, it’s very likely to find some seemingly good complicated interactions among features, for example, by looking at the p-values for a simple linear regression and by tweaking couple of significant ones, I found a (good-turn-out-to-be-useless) interaction between logRevenue~I(P6*P20)+P28+I(P28^2)+Age+I(Age^2)
, including Age of the restaurant, which at the time, I thought it might be a good feature to add, but the deserved winner argued the opposite!
Now, let’s dig into bartMachine. First, I’m going to throw out the incompatible features including Open.Date, City, City.Group, Type as well as Id to make our training and testing sets unified and initialize the bartMachine as follows;
SEED <- 123 set.seed(SEED) options(java.parameters="-Xmx5000m") # must be set initially library(bartMachine) set_bart_machine_num_cores(4) y <- log(revenue) X <- all[1:nrow(train), -c(1,2,3,4,5)] X_test <- all[(nrow(train)+1):nrow(all), -c(1,2,3,4,5)] bart <- bartMachine(X, y, num_trees=10, seed=SEED) # default num_trees might be excessive
The results are
bartMachine v1.2.0 for regression
training data n = 137 and p = 37
built in 0.3 secs on 4 cores, 10 trees, 250 burn-in and 1000 post. samplessigsq est for y beforehand: 0.159
avg sigsq estimate after burn-in: 0.19178in-sample statistics:
L1 = 41.5
L2 = 22.15
rmse = 0.4
Pseudo-Rsq = 0.2953
p-val for shapiro-wilk test of normality of residuals: 0.08099
p-val for zero-mean noise: 0.97545
We can also see the effect of changing the num_trees
parameter by
rmse_by_num_trees(bart, tree_list=c(seq(5, 50, by=5)), num_replicates=5)
So perhaps num_trees=20
works better with less out-of-sample error.
As, bartMachine is essentially doing MCMC and heavily using Gibbs sampler, checking convergence seems necessary (see the above paper for more details);
bart <- bartMachine(X, y, num_trees=20, seed=SEED) plot_convergence_diagnostics(bart)
We can also check mean-centeredness assumption and heteroskedasticity of the data by the QQ-plot
check_bart_error_assumptions(bart)
One of the main features of bartMachine is showing variable importance via permutation tests
var_selection_by_permute(bart, num_reps_for_avg=20)
Moreover, bartMachine is capable of finding explicit interactions among feature
interaction_investigate(bart, num_replicates_for_avg=5)
However, since we want a simple model to start and score well in the poisoned test set, focusing too much on interactions can hurt a lot, as happened to me by inevitable overfitting. Now, I’m going to pick P20, P17, P6 and retrain my model and perform a CV and finally submit it.
nX <- X[, c("P6", "P17", "P28")] nX_test <- X_test[, c("P6", "P17", "P28")] nbart <- bartMachine(nX, y, num_trees=20, seed=SEED) nbart_cv <- bartMachineCV(nX, y, num_tree_cvs=c(10,15,20), k_folds=5, verbose=TRUE) # bartMachine CV win: k: 3 nu, q: 3, 0.9 m: 20 log_pred <- predict(nbart_cv, nX_test) pred <- exp(log_pred) submit <- data.frame(Id=test$Id, Prediction = pred) write.csv(submit, file="bartMachine_P6P17P28.csv", row.names=F)
The result of this submission, is 1731437.14026 in public and 1839841.77388 (ranks around 700 out of 2257 teams) in private leaderboards.
Rank below 100 out of 2257:
To improve our model and make it more robust, we should remove outliers by deciding what threshold to use perhaps with some trial-and-error first, though there’re many methods that can be helpful, which is not within the scope of this post. Looking at the log(revenue)
distribution should suggest where you should put the threshold. Doing so, it can hugely boost your private leaderboard ranking to position around 80 with score 1786605.94364 which is a fairly good incentive to explore bartMachine more.
You can find the codes including the improved model and plots in my github. Enjoy!
P.S. During the competition, the BAYZ,M.D team managed to make a perfect overfit on the public leaderboard with interesting and creative methods. You can read more about that here.
Hi great reading yoour blog