library('ggplot2')
library('reshape2')
library('randomForest')
## randomForest 4.6-10
## Type rfNews() to see new features/changes/bug fixes.
library('glmnet')
## Loading required package: Matrix
## Loaded glmnet 1.9-8
library('kernlab')
library('ROCR')
## Loading required package: gplots
##
## Attaching package: 'gplots'
##
## The following object is masked from 'package:stats':
##
## lowess
library('plyr')
library('Hmisc')
## Loading required package: grid
## Loading required package: lattice
## Loading required package: survival
## Loading required package: splines
## Loading required package: Formula
##
## Attaching package: 'Hmisc'
##
## The following objects are masked from 'package:plyr':
##
## is.discrete, summarize
##
## The following object is masked from 'package:randomForest':
##
## combine
##
## The following objects are masked from 'package:base':
##
## format.pval, round.POSIXt, trunc.POSIXt, units
# Load the ISOLET spoken-letter dataset: folds 1-4 form the training pool,
# fold 5 is held out as the test set. 617 numeric features per utterance,
# class label in column V618.
dTrain = read.table("isolet1+2+3+4.data.gz",
header=FALSE,sep=',',
stringsAsFactors=FALSE,blank.lines.skip=TRUE)
dTrain$isTest <- FALSE
dTest = read.table("isolet5.data.gz",
header=FALSE,sep=',',
stringsAsFactors=FALSE,blank.lines.skip=TRUE)
dTest$isTest <- TRUE
# Combine into one frame flagged by isTest, then drop the originals.
d <- rbind(dTrain,dTest)
rm(list=c('dTest','dTrain'))
# V618 is the letter code 1..26; map it to 'a'..'z'.
d$V618 <- letters[d$V618]
vars <- colnames(d)[1:617]
# Binary target: is the spoken letter 'n'?
yColumn <- 'isN'
d[,yColumn] <- d[,'V618']=='n'
# true prevalence of N
mean(d$isN)
## [1] 0.03847634
# Functions ----
# draw a set of size N where target variable has prevalence <prevalence>
# assume target is a T/F variable
makePrevalence = function(dataf, target, prevalence, N) {
  # sample(x, ...) treats a length-1 numeric x as 1:x, so sampling directly
  # from an index vector that happens to contain a single row index would
  # silently draw from the wrong rows. Subscript via sample.int() instead.
  resample = function(ix, size) {
    ix[sample.int(length(ix), size = size, replace = TRUE)]
  }
  # indices of T/F rows
  tset_ix = which(dataf[[target]])
  others_ix = which(!dataf[[target]])
  # number of TRUE rows needed to hit the requested prevalence
  ntarget = round(N*prevalence)
  heads = resample(tset_ix, ntarget)
  tails = resample(others_ix, N - ntarget)
  dataf[c(heads, tails),]
}
# to make sure that both factors T and F are represented,
# even if the column is purely TRUE or purely FALSE
# (so table() on the result always yields a full 2x2 confusion matrix).
bool2factor = function(boolcol) {
  # Declaring the levels explicitly replaces the original pad-with-c(T,F)
  # then-drop hack, and avoids the reassignable T/F shortcuts.
  factor(boolcol, levels = c(FALSE, TRUE))
}
# Compute classification metrics for boolean outcomes y against scores pred,
# thresholded at <threshold>. Returns a one-row data.frame tagged with the
# supplied prevalence and model label. bool2factor() guarantees both levels
# appear so the confusion matrix is always 2x2.
metrics = function(y, pred, prevalence, label, threshold=0.5) {
  predicted_pos = pred > threshold
  cmat = table(outcome = bool2factor(y),
               prediction = bool2factor(predicted_pos))
  n_correct = cmat[1, 1] + cmat[2, 2]
  data.frame(prevalence = prevalence,
             accuracy = n_correct / sum(cmat),
             precision = cmat[2, 2] / sum(cmat[, 2]),
             recall = cmat[2, 2] / sum(cmat[2, ]),          # also sensitivity
             specificity = cmat[1, 1] / sum(cmat[1, ]),     # 1-FPR
             label = label)
}
# posit a test with a given sensitivity/specificity
# and populations with different prevalences of diabetes
# Works in population fractions: returns the theoretical accuracy and
# precision implied by (sens, spec) at the given prevalence.
scenario = function(sens, spec, prevalence, trainPrevalence) {
  pos_frac = prevalence
  neg_frac = 1 - prevalence
  true_pos = pos_frac * sens      # positives correctly diagnosed
  true_neg = neg_frac * spec      # negatives correctly diagnosed
  false_pos = neg_frac - true_neg # negatives mis-diagnosed as positive
  data.frame(train_prevalence = trainPrevalence,
             th_accuracy = true_pos + true_neg,
             th_precision = true_pos / (true_pos + false_pos),
             th_sens = sens, th_spec = spec)
}
# Fit a ridge-regularized (alpha=0) logistic regression with cross-validated
# glmnet; returns a closure that scores a new frame as class probabilities.
fit_logit = function(data, vars, yColumn) {
  # note: as.matrix should only be called if all the
  # vars are numerical; otherwise use model.matrix
  xmat = as.matrix(data[, vars])
  fitted = cv.glmnet(x = xmat, y = data[[yColumn]],
                     alpha = 0,
                     family = 'binomial')
  function(d) {
    predict(fitted, newx = as.matrix(d[, vars]), type = 'response')[, 1]
  }
}
# Fit a random forest classifier (y coerced to factor to force
# classification mode); returns a closure yielding P(TRUE) per row.
fit_rf = function(data, vars, yColumn) {
  fitted = randomForest(x = data[, vars], y = as.factor(data[[yColumn]]))
  function(d) {
    predict(fitted, newdata = d, type = 'prob')[, 'TRUE', drop = TRUE]
  }
}
# Fit an RBF-kernel C-SVM with probability estimates enabled
# (prob.model=TRUE); returns a closure yielding P(TRUE) per row.
fit_svm = function(data, vars, yColumn) {
  # build "y ~ v1+v2+..." from the variable names
  fmla = as.formula(paste(yColumn, "~", paste(vars, collapse = "+")))
  fitted = ksvm(fmla, data = data, type = 'C-svc', C = 1.0,
                kernel = 'rbfdot',
                prob.model = TRUE)
  function(d) {
    predict(fitted, newdata = d, type = 'prob')[, 'TRUE', drop = TRUE]
  }
}
#
# stats is a dataframe whose rows are the output of metrics()
#
# Plot one metric against training prevalence, faceted by model label.
# Points are individual runs, the solid line is the per-prevalence mean,
# error bars are bootstrap CIs, and the dashed hline marks each model's
# mean at the baseline prevalence <baseprev>.
statsPlot = function(stats, metric, baseprev) {
# per-label mean of <column> among rows at the baseline prevalence;
# supplies the reference hline for each facet
baseline = function(column) {
fmla = as.formula(paste(column, "~ label"))
aggregate(fmla, data=subset(stats, stats$prevalence==baseprev), FUN=mean)
}
ggplot(stats, aes_string(x="prevalence", y=metric, color="label")) +
geom_point(alpha=0.2) +
stat_summary(fun.y="mean", geom="line") +
stat_summary(fun.data="mean_cl_boot", geom="errorbar") +
geom_hline(data=baseline(metric), aes_string(yintercept=metric, color="label"), linetype=2) +
scale_x_continuous("training prevalence") +
ggtitle(metric) + facet_wrap(~label, ncol=1)
}
# Do the run ----
# Setup ----
# set the random seed
set.seed(5246253)
# use the test set at the base prevalence
test = d[d$isTest,]
basePrev = mean(test[[yColumn]])
trainall = d[!d$isTest,]
# the training set will have various enriched preferences:
# multiples (1x, 2x, 5x, 10x) of the base prevalence, plus a 50/50 split
prevalences = c(c(1, 2, 5, 10)*basePrev, 0.5)
# size of each resampled training set
N=2000
# reaches into global environment
# (uses trainall, test, vars, yColumn, N defined above)
# Draw a training set at prevalence <prev>, fit all three model types,
# and return their metrics() rows evaluated on the global test set.
# When verbose, also prints training/test metrics and score-density plots.
makeModelsAndScore = function(prev, verbose=F) {
train = makePrevalence(trainall, yColumn, prev, N)
models = list(fit_logit(train, vars, yColumn),
fit_rf(train, vars, yColumn),
fit_svm(train, vars, yColumn))
names(models) = c("logistic",
"random forest",
"svm")
nmodels = length(models)
if(verbose) {
# density of predicted scores, colored by true outcome
# NOTE: densityPlot is defined here but also used in the second
# verbose block below; both branches run iff verbose is TRUE.
densityPlot = function(predcolumn, dataf, title) {
ggplot(cbind(dataf, pred=predcolumn), aes_string(x="pred", color=yColumn)) +
geom_density() + ggtitle(title)
}
# training prevalences
print("Metrics on training data")
for(i in seq_len(nmodels)) {
print(densityPlot(models[[i]](train), train, paste(names(models)[i], ": training")))
}
f <- function(i) {metrics(train[[yColumn]],
models[[i]](train),
prev,
names(models)[i])}
metricf = ldply(seq_len(nmodels), .fun=f)
print(metricf)
}
# test-set metrics; this metricf (overwriting the training one, if any)
# is the function's return value
f <- function(i) {metrics(test[[yColumn]],
models[[i]](test),
prev,
names(models)[i])}
metricf = ldply(seq_len(nmodels), .fun=f)
if(verbose) {
# test prevalences
print("Metrics on test data")
print(metricf)
for(i in seq_len(nmodels)) {
print(densityPlot(models[[i]](test), test, paste(names(models)[i], ": test")))
}
}
metricf
}
# Run ----
# Sanity check: one full fit/score cycle at the true base prevalence, verbose.
makeModelsAndScore(basePrev, verbose=T)
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## [1] "Metrics on training data"
## prevalence accuracy precision recall specificity label
## 1 0.03848621 0.9985 1.0000000 0.961039 1.00000 logistic
## 2 0.03848621 1.0000 1.0000000 1.000000 1.00000 random forest
## 3 0.03848621 0.9975 0.9736842 0.961039 0.99896 svm
## [1] "Metrics on test data"
## prevalence accuracy precision recall specificity label
## 1 0.03848621 0.9807569 0.7777778 0.7000000 0.9919947 logistic
## 2 0.03848621 0.9717768 1.0000000 0.2666667 1.0000000 random forest
## 3 0.03848621 0.9846055 0.7903226 0.8166667 0.9913276 svm
## prevalence accuracy precision recall specificity label
## 1 0.03848621 0.9807569 0.7777778 0.7000000 0.9919947 logistic
## 2 0.03848621 0.9717768 1.0000000 0.2666667 1.0000000 random forest
## 3 0.03848621 0.9846055 0.7903226 0.8166667 0.9913276 svm
# stats = ldply(prevalences, .fun=makeModelsAndScore)
# Repeat the full prevalence sweep 10 times so the plots get error bars.
stats = ldply(1:10, .fun=function(x){ldply(prevalences, .fun=makeModelsAndScore)})
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
# Plot ----
# Compare each metric across training prevalences for the three models.
statsPlot(stats, "accuracy", basePrev)
statsPlot(stats, "precision", basePrev)
statsPlot(stats, "recall", basePrev)
statsPlot(stats, "specificity", basePrev)
# One last experiment: what about just sweeping the decision threshold?
# Refit all three models once at the base (true) prevalence, so we can
# sweep the decision threshold instead of the training prevalence.
train = makePrevalence(trainall, yColumn, basePrev, N)
models = list(fit_logit(train, vars, yColumn),
fit_rf(train, vars, yColumn),
fit_svm(train, vars, yColumn))
## Using automatic sigma estimation (sigest) for RBF or laplace kernel
names(models) = c("logistic",
"random forest",
"svm")
nmodels = length(models)
## ---------- functions --------------
# reaches into global environment for test set and ycolumn
# model is the output of one of the fit_xxx functions:
# a function to make predictions from the trained model
# Evaluates metrics() on the test set at each threshold in 0.25..0.75;
# the threshold is recorded in the metrics "prevalence" column.
sweepThresholds = function(model, label) {
  thresholds = seq(from=0.25, to=0.75, by=0.05)
  # score the test set once; predictions don't depend on the threshold
  scores = model(test)
  evalAt = function(th) {
    metrics(test[[yColumn]], scores, th, label, threshold=th)
  }
  ldply(thresholds, .fun=evalAt)
}
# Plot one metric as a function of decision threshold, faceted by model.
# NOTE(review): the x column is named "prevalence" because sweepThresholds
# stores the threshold in metrics()'s prevalence slot; the axis is relabeled.
threshPlot = function(stats, metric) {
ggplot(stats, aes_string(x="prevalence", y=metric, color="label")) +
geom_point() +
stat_summary(fun.y="mean", geom="line") +
stat_summary(fun.data="mean_cl_boot", geom="errorbar") +
scale_x_continuous("threshold") +
ggtitle(metric) + facet_wrap(~label, ncol=1)
}
## ---------- end functions --------------
# Sweep thresholds for each fitted model and stack the per-threshold metrics.
stats = ldply(seq_len(nmodels), .fun=function(i) {sweepThresholds(models[[i]], names(models)[i])})
threshPlot(stats, "accuracy")
threshPlot(stats, "precision")
## Warning: Removed 2 rows containing missing values (stat_summary).
## Warning: Removed 2 rows containing missing values (stat_summary).
## Warning: Removed 2 rows containing missing values (geom_point).
threshPlot(stats, "recall")
threshPlot(stats, "specificity")