library('randomForest')
## randomForest 4.6-10
## Type rfNews() to see new features/changes/bug fixes.
library('ggplot2')
# Fit a random-forest classifier and return a prediction closure.
#
# vars:     character vector of predictor column names in `data`.
# yTarget:  name of the outcome column. It is converted to a factor via its
#           character representation, so a logical outcome yields the class
#           levels "FALSE" and "TRUE".
# data:     training data frame.
# ntree:    number of trees to grow (passed to randomForest()).
# maxnodes: maximum number of terminal nodes per tree (passed to
#           randomForest()).
#
# Returns a function(newd) that gives, for each row of the data frame
# `newd`, the forest's predicted probability of the class "TRUE".
rfFitter <- function(vars, yTarget, data, ntree = 100, maxnodes = 10) {
  model <- randomForest(
    x = data[, vars, drop = FALSE],
    y = as.factor(as.character(data[, yTarget, drop = TRUE])),
    ntree = ntree,
    maxnodes = maxnodes
  )
  function(newd) {
    # type = "prob" yields one column per class level; keep the "TRUE" one.
    predict(model, newdata = newd, type = "prob")[, "TRUE"]
  }
}
# Fit a logistic-regression classifier and return a prediction closure.
#
# vars:    character vector of predictor column names in `data`. An empty
#          vector fits an intercept-only model instead of failing.
# yTarget: name of the outcome column (e.g. a logical TRUE/FALSE column).
# data:    training data frame.
#
# Returns a function(newd) that gives, for each row of the data frame
# `newd`, the predicted probability (type = "response" on the logit link).
logisticFitter <- function(vars, yTarget, data) {
  # reformulate() builds `yTarget ~ v1 + v2 + ...` without pasting formula
  # text by hand; "1" stands in for an empty right-hand side (y ~ 1).
  rhs <- if (length(vars) > 0) vars else "1"
  modelFormula <- reformulate(rhs, response = yTarget)
  # `data` is passed by name: glm()'s second positional formal is `family`.
  model <- glm(modelFormula,
               data = data,
               family = binomial(link = "logit"))
  function(newd) {
    predict(model, newdata = newd, type = "response")
  }
}
# Registry of model fitters, keyed by the display name used in the results.
# Each entry has signature function(vars, yTarget, data) -> function(newd).
fitters <- list(RandomForest = rfFitter, Logistic = logisticFitter)
# Example data: the predictor column names x1, x2, x3.
vars <- paste0("x", seq_len(3))
# Build an example binary-classification data frame with n rows.
#
# Each column named in `vars` is filled with independent runif(n) draws.
# The logical column `y` is TRUE when the row's sum over those columns is at
# least length(vars) / 2 plus standard-normal noise, so y is a noisy
# threshold function of the numeric predictors.
#
# vars: character vector of predictor column names (should not contain "y").
# n:    number of rows.
datn <- function(vars, n) {
  d <- as.data.frame(matrix(data = 0, nrow = n, ncol = length(vars)))
  names(d) <- vars
  for (vi in vars) {
    d[, vi] <- runif(n)
  }
  # Sum the predictor columns left to right (the same arithmetic as
  # x1 + x2 + x3); the 0 seed keeps the result well-defined for empty vars.
  rowTotal <- Reduce(`+`, d[vars], 0)
  # length(vars) / 2 is the mean of a sum of length(vars) uniforms
  # (1.5 for the three-variable example).
  d$y <- rowTotal >= length(vars) / 2 + rnorm(n)
  d
}
# Simulated training (1000 rows) and test (100 rows) sets, then the
# random-forest probabilities for the test rows (auto-printed).
dTrain <- datn(vars, n = 1000)
dTest <- datn(vars, n = 100)
rfFitter(vars = vars, yTarget = "y", data = dTrain)(dTest)
## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
## 0.91 0.57 0.78 0.10 0.60 0.21 0.82 0.94 0.28 0.24 0.15 0.33 0.83 0.95 0.77
## 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
## 0.21 0.33 0.24 0.14 0.35 0.22 0.08 0.07 0.26 0.74 0.12 0.26 0.18 0.46 0.88
## 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
## 0.52 0.49 0.03 0.76 0.68 0.83 0.79 0.81 0.37 0.46 0.83 0.22 0.16 0.50 0.38
## 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
## 0.79 0.56 0.83 0.57 0.66 0.16 0.86 0.04 0.35 0.81 0.50 0.20 0.50 0.07 0.79
## 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
## 0.12 0.78 0.81 0.24 0.22 0.94 0.35 0.58 0.94 0.65 0.92 0.91 0.69 0.62 0.03
## 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90
## 0.73 0.17 0.71 0.03 0.70 0.06 0.65 0.73 0.07 0.40 0.35 0.01 0.85 0.55 0.51
## 91 92 93 94 95 96 97 98 99 100
## 0.16 0.73 0.16 0.13 0.22 0.47 0.59 0.13 0.41 0.15
logisticFitter(vars = vars, yTarget = "y", data = dTrain)(dTest)
## 1 2 3 4 5 6
## 0.71489607 0.54382168 0.68934741 0.31058741 0.68455720 0.38486226
## 7 8 9 10 11 12
## 0.83116437 0.78100812 0.34534941 0.36823573 0.37731521 0.39199136
## 13 14 15 16 17 18
## 0.72951814 0.80196398 0.66419570 0.34461598 0.33363645 0.25527952
## 19 20 21 22 23 24
## 0.28468355 0.45096656 0.29234710 0.18515067 0.17285655 0.48149062
## 25 26 27 28 29 30
## 0.67143961 0.31878631 0.26236357 0.23225823 0.43908601 0.70913211
## 31 32 33 34 35 36
## 0.52423549 0.44784422 0.16744764 0.73228709 0.72451108 0.74139920
## 37 38 39 40 41 42
## 0.75489516 0.69953952 0.44875397 0.63081325 0.72045428 0.31659571
## 43 44 45 46 47 48
## 0.30011420 0.59767712 0.45007996 0.77233364 0.51741790 0.77982798
## 49 50 51 52 53 54
## 0.56169887 0.66668898 0.41140132 0.75435046 0.18384538 0.33169753
## 55 56 57 58 59 60
## 0.68889220 0.49901793 0.33394117 0.49314214 0.29088447 0.75715052
## 61 62 63 64 65 66
## 0.25150118 0.76101500 0.65950332 0.37599035 0.36464796 0.79525592
## 67 68 69 70 71 72
## 0.49321833 0.48296309 0.76550757 0.59192239 0.81997397 0.75156044
## 73 74 75 76 77 78
## 0.71057616 0.51665233 0.19093330 0.54992062 0.31683260 0.51088071
## 79 80 81 82 83 84
## 0.22048398 0.69998826 0.07067548 0.61994142 0.73065835 0.29893148
## 85 86 87 88 89 90
## 0.34836721 0.29375109 0.13859649 0.73631893 0.55944633 0.58049130
## 91 92 93 94 95 96
## 0.25160779 0.59445440 0.24687087 0.29432841 0.37391306 0.61377199
## 97 98 99 100
## 0.50173639 0.30677075 0.58844046 0.27199578
# Fit one named model on training data and score it on test data.
#
# fitterName: name of an entry in `fitterList`; printed as a progress marker.
# fitterList: named list of fitters, each function(vars, yTarget, data)
#             returning function(newd). Defaults to the global `fitters`.
# varNames:   predictor column names; defaults to the global `vars`.
# yTarget:    outcome column name (default "y").
# trainData:  training data frame; defaults to the global `dTrain`.
# testData:   evaluation data frame; defaults to the global `dTest`.
#
# The defaults are evaluated lazily, so a call with only `fitterName`
# behaves exactly as before while other callers can pass explicit inputs.
#
# Returns a data frame with one row per test row and columns `fitter`
# (the fitter name), `truth` (observed outcome) and `prediction`.
runExperiment <- function(fitterName,
                          fitterList = fitters,
                          varNames = vars,
                          yTarget = "y",
                          trainData = dTrain,
                          testData = dTest) {
  print(fitterName)
  fitter <- fitterList[[fitterName]]
  if (is.null(fitter)) {
    stop("unknown fitter: ", fitterName, call. = FALSE)
  }
  predictor <- fitter(varNames, yTarget, trainData)
  predictions <- predictor(testData)
  data.frame(fitter = fitterName,
             truth = testData[[yTarget]],
             prediction = predictions)
}
# Run every registered fitter and stack the per-fitter results row-wise.
results <- do.call(
  rbind,
  lapply(names(fitters), runExperiment)
)
## [1] "RandomForest"
## [1] "Logistic"
# Density of predicted probabilities split by true outcome, one panel per
# fitter stacked vertically.
ggplot(results) +
  geom_density(mapping = aes(x = prediction, colour = truth)) +
  facet_wrap(facets = ~fitter, ncol = 1)